Keras: fit_generator error

Time: 2018-07-01 14:41:58

Tags: tensorflow keras generator

I am learning CNNs with Keras and am trying to build an object detection model.

Model:

model=MobileNet(weights='imagenet', include_top=True, input_shape=(img_width, img_height, 3))
model.compile(optimizer='rmsprop', loss='binary_crossentropy', metrics=['accuracy'])

For the data, I am using flow_from_directory:

datagen = ImageDataGenerator(featurewise_center=True,
                             featurewise_std_normalization=True,
                             width_shift_range=0.2, height_shift_range=0.2)

train_generator = datagen.flow_from_directory(
                       train_data_dir,
                        target_size=(img_width, img_height),
                       ...)
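
A side note on `featurewise_center` and `featurewise_std_normalization`: these options standardize images with statistics computed over the whole dataset, and Keras expects `datagen.fit(...)` to be called on sample data first so that those statistics exist. The sketch below is a simplified pure-Python illustration of the idea, not the Keras implementation; `channel_stats`, `standardize`, and the tiny two-channel "images" are made up for the example.

```python
import math


def channel_stats(images):
    """Return per-channel (means, stds) over every pixel of every image.

    Each image is a list of pixels; each pixel is a list of channel values.
    """
    n_channels = len(images[0][0])
    means, stds = [], []
    for c in range(n_channels):
        values = [pixel[c] for image in images for pixel in image]
        mean = sum(values) / len(values)
        var = sum((v - mean) ** 2 for v in values) / len(values)
        means.append(mean)
        stds.append(math.sqrt(var))
    return means, stds


def standardize(image, means, stds, eps=1e-6):
    """Subtract the dataset mean and divide by the dataset std, per channel."""
    return [
        [(pixel[c] - means[c]) / (stds[c] + eps) for c in range(len(means))]
        for pixel in image
    ]


# Hypothetical dataset: two images, two pixels each, two channels.
images = [
    [[0.0, 10.0], [2.0, 10.0]],
    [[4.0, 10.0], [6.0, 10.0]],
]
means, stds = channel_stats(images)
first = standardize(images[0], means, stds)
print(means, stds)
print(first)
```

The statistics have to be computed once from data before any batch can be standardized, which is why a generator configured this way normally needs a fitting step before training.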

datagen = validation_datagen = ImageDataGenerator(rescale=1. / 255)

validation_generator = datagen.flow_from_directory(
                       valid_data_dir,
                        target_size=(img_width, img_height),
                        ...)

When I call model.fit_generator:

model.fit_generator(
                train_generator,
                steps_per_epoch=train_samples // batch_size,
                epochs=epochs,
                validation_data=validation_generator,
                validation_steps=valid_samples // batch_size,
                verbose=1)
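
For reference, the `//` in `steps_per_epoch` and `validation_steps` is floor division, so a partial final batch is not counted. A small worked example, with sample counts and batch size that are purely hypothetical (the question does not state them):

```python
def steps_and_leftover(num_samples, batch_size):
    """Return (full batches per epoch, samples not covered by those batches)."""
    steps = num_samples // batch_size
    leftover = num_samples - steps * batch_size
    return steps, leftover


def ceil_steps(num_samples, batch_size):
    """Number of batches needed to cover every sample, counting a partial one."""
    return -(-num_samples // batch_size)


# Hypothetical values, not taken from the question.
train_samples = 1050
batch_size = 32

steps, leftover = steps_and_leftover(train_samples, batch_size)
print(steps, leftover)                        # 32 26
print(ceil_steps(train_samples, batch_size))  # 33
```

With these numbers, an epoch of 32 steps covers 1024 samples, 26 fewer than the full set.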

I get the following error and have not been able to debug it.

<ipython-input-10-fa0bcaa5b96b> in <module>()
      5                 validation_steps=valid_samples // batch_size,
      6                 workers=5,
----> 7                 verbose=1)

E:\Anaconda3\envs\env\lib\site-packages\keras\legacy\interfaces.py in wrapper(*args, **kwargs)
     89                 warnings.warn('Update your `' + object_name +
     90                               '` call to the Keras 2 API: ' + signature, stacklevel=2)
---> 91             return func(*args, **kwargs)
     92         wrapper._original_function = func
     93         return wrapper

E:\Anaconda3\envs\env\lib\site-packages\keras\engine\training.py in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
   1424             use_multiprocessing=use_multiprocessing,
   1425             shuffle=shuffle,
-> 1426             initial_epoch=initial_epoch)
   1427 
   1428     @interfaces.legacy_generator_methods_support

E:\Anaconda3\envs\env\lib\site-packages\keras\engine\training_generator.py in fit_generator(model, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
    189                 outs = model.train_on_batch(x, y,
    190                                             sample_weight=sample_weight,
--> 191                                             class_weight=class_weight)
    192 
    193                 if not isinstance(outs, list):

E:\Anaconda3\envs\env\lib\site-packages\keras\engine\training.py in train_on_batch(self, x, y, sample_weight, class_weight)
   1212             x, y,
   1213             sample_weight=sample_weight,
-> 1214             class_weight=class_weight)
   1215         if self._uses_dynamic_learning_phase():
   1216             ins = x + y + sample_weights + [1.]

E:\Anaconda3\envs\env\lib\site-packages\keras\engine\training.py in _standardize_user_data(self, x, y, sample_weight, class_weight, check_array_lengths, batch_size)
    790                 feed_output_shapes,
    791                 check_batch_axis=False,  # Don't enforce the batch size.
--> 792                 exception_prefix='target')
    793 
    794             # Generate sample-wise weight values given the `sample_weight` and

E:\Anaconda3\envs\env\lib\site-packages\keras\engine\training_utils.py in standardize_input_data(data, names, shapes, check_batch_axis, exception_prefix)
    134                             ': expected ' + names[i] + ' to have shape ' +
    135                             str(shape) + ' but got array with shape ' +
--> 136                             str(data_shape))
    137     return data
    138 

ValueError: Error when checking target: expected reshape_2 to have shape (1000,) but got array with shape (1,)
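
A note on reading this message: Keras compares the per-sample shape of the labels the generator yields with the per-sample shape of the model's output. `MobileNet(..., include_top=True)` keeps the 1000-class ImageNet classifier, so the output layer (`reshape_2` here) has shape `(1000,)`, while the generator is delivering one label value per image, shape `(1,)`. The `class_mode` argument is hidden behind `...` in the question, so treating the labels as binary is an assumption. The sketch below imitates the check in plain Python; `check_target_shape` is a hypothetical helper, not a Keras function.

```python
def check_target_shape(output_name, expected_shape, target_shape):
    """Raise ValueError if the per-sample label shape differs from the
    per-sample output shape (batch axis excluded in both)."""
    if tuple(expected_shape) != tuple(target_shape):
        raise ValueError(
            'Error when checking target: expected ' + output_name +
            ' to have shape ' + str(tuple(expected_shape)) +
            ' but got array with shape ' + str(tuple(target_shape)))


# Output of the stock ImageNet head vs. one label value per image (assumed).
message = None
try:
    check_target_shape('reshape_2', (1000,), (1,))
except ValueError as err:
    message = str(err)
print(message)

# A one-unit output layer would accept the same labels without complaint.
check_target_shape('dense_1', (1,), (1,))
```

Under that assumption, a common remedy is to load the network with `include_top=False` and attach a new output layer whose size matches the labels, rather than reusing the 1000-way ImageNet classifier.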

0 answers:

No answers