在keras模型中拟合一个简单的图像生成器

时间:2018-04-28 06:34:04

标签: keras generator

我有一个带有输入图像和标签值的keras模型。

我有一个数据生成器,可以读取图像,处理它并将其提供给网络

from PIL import Image

def my_iterator():
    i = 0
    while True:
        img_name = train_df.loc[i,'Image']
        img_label = train_df.loc[i,'Id']


        img = Image.open('master_train/'+str(img_name)).convert('L')
        print(img.mode)

        longer_side = max(img.size)
        horizontal_padding = (longer_side - img.size[0]) / 2
        vertical_padding = (longer_side - img.size[1]) / 2
        img = img.crop((-horizontal_padding,-vertical_padding,img.size[0] + horizontal_padding,img.size[1] + vertical_padding))
        img.thumbnail((128,128),Image.ANTIALIAS)

        img_array = np.asarray(img,dtype='uint8')
        img_array = img_array[:,:,np.newaxis]
        print(img_array.ndim)

        yield img_array,img_label
        i  = (i+1) % len(train_df)


from keras.models import Model
from keras.layers import Input,Dense

input_layer = Input(shape=(128,128,1))
x = Dense(100,activation='relu')(input_layer)
output_layer = Dense(1,activation='sigmoid')(x)

model = Model(inputs=input_layer,outputs=output_layer)
      model.compile(loss='binary_crossentropy',optimizer='nadam',metrics['accuracy'])
model.summary()

training_generator = my_iterator()

model.fit(training_generator,steps_per_epoch=1) 

我收到以下错误

AttributeError                            Traceback (most recent call last)
<ipython-input-189-7efa0828e76d> in <module>()
----> 1 model.fit(train_gen,steps_per_epoch=1)

~/work/venvs/keras3/lib/python3.6/site-packages/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, **kwargs)
   1628             sample_weight=sample_weight,
   1629             class_weight=class_weight,
-> 1630             batch_size=batch_size)
   1631         # Prepare validation data.
   1632         do_validation = False

~/work/venvs/keras3/lib/python3.6/site-packages/keras/engine/training.py in _standardize_user_data(self, x, y, sample_weight, class_weight, check_array_lengths, batch_size)
   1474                                     self._feed_input_shapes,
   1475                                     check_batch_axis=False,
-> 1476                                     exception_prefix='input')
   1477         y = _standardize_input_data(y, self._feed_output_names,
   1478                                     output_shapes,

~/work/venvs/keras3/lib/python3.6/site-packages/keras/engine/training.py in _standardize_input_data(data, names, shapes, check_batch_axis, exception_prefix)
     74         data = data.values if data.__class__.__name__ == 'DataFrame' else data
     75         data = [data]
---> 76     data = [np.expand_dims(x, 1) if x is not None and x.ndim == 1 else x for x in data]
     77 
     78     if len(data) != len(names):

~/work/venvs/keras3/lib/python3.6/site-packages/keras/engine/training.py in <listcomp>(.0)
     74         data = data.values if data.__class__.__name__ == 'DataFrame' else data
     75         data = [data]
---> 76     data = [np.expand_dims(x, 1) if x is not None and x.ndim == 1 else x for x in data]
     77 
     78     if len(data) != len(names):

AttributeError: 'generator' object has no attribute 'ndim'

​

1 个答案:

答案 0 :(得分:0)

您应该使用fit_generator来使用生成器训练模型,而不是普通的fit函数。