我有一个训练有素的自动编码器,我想用它来减少图像尺寸,然后使用编码后的图像训练CNN。 如何使用编码后的图像训练CNN?我想使用Fit Generator来返回编码的图像以及相应的标签。
def custom_generator(generator):
for data, labels in generator:
data=encoder.predict(data)
yield data, labels
model.fit_generator(custom_generator(train_generator), steps_per_epoch=num_train_steps, epochs=25,validation_data=custom_generator(validation_generator), validation_steps=num_valid_steps)
这是我得到的错误:
Epoch 1/25
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-33-54d2c6697155> in <module>()
----> 1 model.fit_generator(custom_generator(train_generator), steps_per_epoch=num_train_steps, epochs=25,validation_data=custom_generator(validation_generator), validation_steps=num_valid_steps)
/usr/local/lib/python3.6/dist-packages/keras/legacy/interfaces.py in wrapper(*args, **kwargs)
89 warnings.warn('Update your `' + object_name +
90 '` call to the Keras 2 API: ' + signature, stacklevel=2)
---> 91 return func(*args, **kwargs)
92 wrapper._original_function = func
93 return wrapper
/usr/local/lib/python3.6/dist-packages/keras/models.py in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
1313 use_multiprocessing=use_multiprocessing,
1314 shuffle=shuffle,
-> 1315 initial_epoch=initial_epoch)
1316
1317 @interfaces.legacy_generator_methods_support
/usr/local/lib/python3.6/dist-packages/keras/legacy/interfaces.py in wrapper(*args, **kwargs)
89 warnings.warn('Update your `' + object_name +
90 '` call to the Keras 2 API: ' + signature, stacklevel=2)
---> 91 return func(*args, **kwargs)
92 wrapper._original_function = func
93 return wrapper
/usr/local/lib/python3.6/dist-packages/keras/engine/training.py in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
2192 batch_index = 0
2193 while steps_done < steps_per_epoch:
-> 2194 generator_output = next(output_generator)
2195
2196 if not hasattr(generator_output, '__len__'):
/usr/local/lib/python3.6/dist-packages/keras/utils/data_utils.py in get(self)
791 success, value = self.queue.get()
792 if not success:
--> 793 six.reraise(value.__class__, value, value.__traceback__)
/usr/local/lib/python3.6/dist-packages/six.py in reraise(tp, value, tb)
691 if value.__traceback__ is not tb:
692 raise value.with_traceback(tb)
--> 693 raise value
694 finally:
695 value = None
/usr/local/lib/python3.6/dist-packages/keras/utils/data_utils.py in _data_generator_task(self)
656 # => Serialize calls to
657 # infinite iterator/generator's next() function
--> 658 generator_output = next(self._generator)
659 self.queue.put((True, generator_output))
660 else:
<ipython-input-32-81cd29d5c219> in custom_generator(generator)
1 def custom_generator(generator):
----> 2 for data, labels in generator:
3 data=encoder.predict(data)
4 yield data, labels
ValueError: too many values to unpack (expected 2)