我有一个 Keras 模型,它以图像作为输入,以标签值作为目标。
我还有一个数据生成器,它读取图像、对其进行处理,然后将其提供给网络:
from PIL import Image
def my_iterator():
    """Yield ``(image_array, label)`` pairs from ``train_df`` forever.

    For every row of the module-level ``train_df`` the file named in the
    ``'Image'`` column is opened from ``master_train/``, converted to 8-bit
    greyscale, padded to a centred square and resized to 128x128. The
    generator wraps around to the first row after the last one, so it never
    terminates.

    Yields:
        tuple: ``(img_array, img_label)`` where ``img_array`` is a ``uint8``
        array of shape ``(128, 128, 1)`` and ``img_label`` is the row's
        ``'Id'`` value.

    NOTE(review): each yield is a single sample without a leading batch
    axis; Keras generator-based training normally expects batches, so
    confirm that the consumer adds one.
    """
    target_size = (128, 128)
    i = 0
    while True:
        # Positional access, so this also works when the frame's index is
        # not the default 0..n-1 range (e.g. after filtering or shuffling).
        img_name = train_df['Image'].iloc[i]
        img_label = train_df['Id'].iloc[i]
        img = Image.open('master_train/' + str(img_name)).convert('L')

        # Pad the shorter side so the image becomes a centred square.
        # Integer offsets keep the crop box exactly longer_side wide and
        # high; float halves would be rounded inside crop() and could give
        # a box that is off by one pixel.
        width, height = img.size
        longer_side = max(width, height)
        left = (longer_side - width) // 2
        top = (longer_side - height) // 2
        img = img.crop((-left, -top, longer_side - left, longer_side - top))

        # resize() (unlike thumbnail()) also enlarges small images, so the
        # result is always target_size. LANCZOS is the filter formerly
        # exposed as Image.ANTIALIAS.
        img = img.resize(target_size, Image.LANCZOS)

        # Append a trailing channel axis: (128, 128) -> (128, 128, 1).
        img_array = np.asarray(img, dtype='uint8')[:, :, np.newaxis]
        yield img_array, img_label
        i = (i + 1) % len(train_df)
from keras.models import Model
from keras.layers import Input,Dense
# Build a small fully connected model on 128x128 single-channel inputs.
# NOTE(review): Dense layers act on the last axis only, so without a
# flattening step the output keeps the spatial axes - confirm that this
# matches the scalar labels produced by the generator.
input_layer = Input(shape=(128, 128, 1))
x = Dense(100, activation='relu')(input_layer)
output_layer = Dense(1, activation='sigmoid')(x)
model = Model(inputs=input_layer, outputs=output_layer)

# `metrics` must be passed as a keyword argument; the bare subscript
# `metrics['accuracy']` after keyword arguments is a SyntaxError.
model.compile(
    loss='binary_crossentropy',
    optimizer='nadam',
    metrics=['accuracy'],
)
model.summary()

training_generator = my_iterator()
# NOTE(review): whether fit() accepts a generator depends on the installed
# Keras version; older releases provide fit_generator() for this purpose.
model.fit(training_generator, steps_per_epoch=1)
运行后,我收到以下错误:
AttributeError Traceback (most recent call last)
<ipython-input-189-7efa0828e76d> in <module>()
----> 1 model.fit(train_gen,steps_per_epoch=1)
~/work/venvs/keras3/lib/python3.6/site-packages/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, **kwargs)
1628 sample_weight=sample_weight,
1629 class_weight=class_weight,
-> 1630 batch_size=batch_size)
1631 # Prepare validation data.
1632 do_validation = False
~/work/venvs/keras3/lib/python3.6/site-packages/keras/engine/training.py in _standardize_user_data(self, x, y, sample_weight, class_weight, check_array_lengths, batch_size)
1474 self._feed_input_shapes,
1475 check_batch_axis=False,
-> 1476 exception_prefix='input')
1477 y = _standardize_input_data(y, self._feed_output_names,
1478 output_shapes,
~/work/venvs/keras3/lib/python3.6/site-packages/keras/engine/training.py in _standardize_input_data(data, names, shapes, check_batch_axis, exception_prefix)
74 data = data.values if data.__class__.__name__ == 'DataFrame' else data
75 data = [data]
---> 76 data = [np.expand_dims(x, 1) if x is not None and x.ndim == 1 else x for x in data]
77
78 if len(data) != len(names):
~/work/venvs/keras3/lib/python3.6/site-packages/keras/engine/training.py in <listcomp>(.0)
74 data = data.values if data.__class__.__name__ == 'DataFrame' else data
75 data = [data]
---> 76 data = [np.expand_dims(x, 1) if x is not None and x.ndim == 1 else x for x in data]
77
78 if len(data) != len(names):
AttributeError: 'generator' object has no attribute 'ndim'
答案 0 :(得分:0)
要使用生成器训练模型,您应该使用 `fit_generator` 函数,而不是普通的 `fit` 函数。