使用自动编码器编码后的图像训练 CNN

时间:2018-09-20 21:54:55

标签: python tensorflow machine-learning keras conv-neural-network

我有一个已经训练好的自动编码器，想用它的编码器部分降低图像的维度，然后用编码后的图像训练一个 CNN。如何用编码后的图像训练 CNN？我想通过 `fit_generator` 配合一个自定义生成器，返回编码后的图像及其对应的标签。

def custom_generator(generator, encode=None):
    """Yield batches from *generator* with their inputs passed through an encoder.

    Each batch produced by *generator* must be a non-empty tuple or list whose
    first element is the model input, e.g. ``(inputs, labels)`` or
    ``(inputs, labels, sample_weights)``. The first element is replaced by
    ``encode(inputs)``; every remaining element is passed through unchanged.

    Args:
        generator: Iterable of batches, such as a Keras data generator.
        encode: Callable applied to each batch's inputs. Defaults to
            ``encoder.predict`` (the module-level ``encoder`` model), which is
            resolved when iteration starts.

    Yields:
        A tuple ``(encoded_inputs, *rest)`` for each batch.

    Raises:
        ValueError: If a batch is not a non-empty tuple/list. This is what
            happens when the source yields bare input arrays without labels.
    """
    if encode is None:
        encode = encoder.predict
    for batch in generator:
        # A bare array (e.g. a generator created without labels) would be
        # unpacked along its batch axis, which is the origin of the opaque
        # "too many values to unpack (expected 2)" error; fail clearly instead.
        if not isinstance(batch, (tuple, list)) or not batch:
            raise ValueError(
                "custom_generator expects each batch to be a non-empty tuple "
                "such as (inputs, labels); got %s. If the source generator "
                "yields only inputs, configure it to also yield labels."
                % type(batch).__name__
            )
        data, *rest = batch
        # Keep labels and any sample weights unchanged after the encoded inputs.
        yield (encode(data), *rest)
# Train on encoder outputs: both the training and validation streams are
# wrapped so the CNN only ever sees encoded images alongside their labels.
model.fit_generator(
    custom_generator(train_generator),
    steps_per_epoch=num_train_steps,
    epochs=25,
    validation_data=custom_generator(validation_generator),
    validation_steps=num_valid_steps,
)

这是我得到的错误:

     Epoch 1/25
    ---------------------------------------------------------------------------
    ValueError                                Traceback (most recent call last)
    <ipython-input-33-54d2c6697155> in <module>()
    ----> 1 model.fit_generator(custom_generator(train_generator), steps_per_epoch=num_train_steps, epochs=25,validation_data=custom_generator(validation_generator), validation_steps=num_valid_steps)

    /usr/local/lib/python3.6/dist-packages/keras/legacy/interfaces.py in wrapper(*args, **kwargs)
         89                 warnings.warn('Update your `' + object_name +
         90                               '` call to the Keras 2 API: ' + signature, stacklevel=2)
    ---> 91             return func(*args, **kwargs)
         92         wrapper._original_function = func
         93         return wrapper

    /usr/local/lib/python3.6/dist-packages/keras/models.py in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
       1313                                         use_multiprocessing=use_multiprocessing,
       1314                                         shuffle=shuffle,
    -> 1315                                         initial_epoch=initial_epoch)
  1316 
   1317     @interfaces.legacy_generator_methods_support

/usr/local/lib/python3.6/dist-packages/keras/legacy/interfaces.py in wrapper(*args, **kwargs)
     89                 warnings.warn('Update your `' + object_name +
     90                               '` call to the Keras 2 API: ' + signature, stacklevel=2)
---> 91             return func(*args, **kwargs)
     92         wrapper._original_function = func
     93         return wrapper

/usr/local/lib/python3.6/dist-packages/keras/engine/training.py in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
   2192                 batch_index = 0
   2193                 while steps_done < steps_per_epoch:
-> 2194                     generator_output = next(output_generator)
   2195 
   2196                     if not hasattr(generator_output, '__len__'):

/usr/local/lib/python3.6/dist-packages/keras/utils/data_utils.py in get(self)
    791             success, value = self.queue.get()
    792             if not success:
--> 793                 six.reraise(value.__class__, value, value.__traceback__)

/usr/local/lib/python3.6/dist-packages/six.py in reraise(tp, value, tb)
    691             if value.__traceback__ is not tb:
    692                 raise value.with_traceback(tb)
--> 693             raise value
    694         finally:
    695             value = None
/usr/local/lib/python3.6/dist-packages/keras/utils/data_utils.py in _data_generator_task(self)
    656                             # => Serialize calls to
    657                             # infinite iterator/generator's next() function
--> 658                             generator_output = next(self._generator)
    659                             self.queue.put((True, generator_output))
    660                         else:

<ipython-input-32-81cd29d5c219> in custom_generator(generator)
      1 def custom_generator(generator):
----> 2   for data, labels in generator:
      3     data=encoder.predict(data)
      4     yield data, labels

ValueError: too many values to unpack (expected 2)

0 个答案:

没有答案