Error when predicting with the encoder model after training an autoencoder

Time: 2020-05-13 21:55:16

Tags: python-3.x autoencoder

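After the autoencoder has finished training, I want to use only the encoder part to extract encodings. However, predicting with the encoder model after training fails. I used the following code to train the autoencoder: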

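from tensorflow.keras import metrics
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.layers import Conv2D, MaxPooling2D, UpSampling2D
from tensorflow.keras.models import Model

# input_img (a Keras Input tensor), img_channel, epochs and train_generator
# are defined elsewhere (not shown in the question)
x = Conv2D(32, (3, 3), activation='relu', padding='same')(input_img)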
x = MaxPooling2D((2, 2), padding='same')(x)
x = Conv2D(32, (3, 3), activation='relu', padding='same')(x)
x = MaxPooling2D((2, 2), padding='same')(x)
x = Conv2D(16, (3, 3), activation='relu', padding='same')(x)
x = Conv2D(8, (3, 3), activation='relu', padding='same')(x)
x = MaxPooling2D((2, 2), padding='same')(x)
x = Conv2D(8, (3, 3), activation='relu', padding='same')(x)
encoded = MaxPooling2D((2, 2), padding='same')(x)

x = Conv2D(8, (3, 3), activation='relu', padding='same')(encoded)
x = UpSampling2D((2, 2))(x)
x = Conv2D(8, (3, 3), activation='relu', padding='same')(x)
x = Conv2D(16, (3, 3), activation='relu', padding='same')(x)
x = UpSampling2D((2, 2))(x)
x = Conv2D(32, (3, 3), activation='relu', padding='same')(x)
x = UpSampling2D((2, 2))(x)
x = Conv2D(32, (3, 3), activation='relu', padding='same')(x)
x = UpSampling2D((2, 2))(x)
decoded = Conv2D(img_channel, (3, 3), activation='sigmoid', padding='same')(x) # example from documentation

autoencoder = Model(input_img, decoded)
autoencoder.summary() # show model data

# create an encoder for debugging purposes later
encoder = Model(input_img, encoded)
# encoder = Model(autoencoder.input,autoencoder.layers[-10].output)

autoencoder.compile(optimizer='sgd', loss='mean_squared_error', metrics=[metrics.mae, metrics.categorical_accuracy])

# do not run forever but stop if model does not get better
stopper = EarlyStopping(monitor='mean_absolute_error', min_delta=0.0001, patience=5, mode='auto', verbose=1)

# do the actual fitting
autoencoder_train = autoencoder.fit_generator(
        train_generator,
        #validation_data=validation_generator,
        epochs=epochs,
        shuffle=False,
        callbacks=[stopper])
# Find out encoding of train data
train_encoding = encoder.predict(train_generator)
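The encoder above applies four MaxPooling2D((2, 2), padding='same') layers and ends with 8 filters, so `train_encoding` should have shape (batch, ceil(H/16), ceil(W/16), 8). The decoder's four UpSampling2D((2, 2)) layers restore the original H and W exactly only when both are multiples of 16. A minimal pure-Python sketch of that shape arithmetic, with hypothetical image sizes:

```python
import math


def encoded_shape(height, width, pools=4, channels=8):
    """Spatial shape after `pools` MaxPooling2D((2, 2), padding='same') layers."""
    for _ in range(pools):
        # With padding='same', a 2x2 pool with stride 2 gives ceil(n / 2)
        height, width = math.ceil(height / 2), math.ceil(width / 2)
    return height, width, channels


def decoded_size(height, width, pools=4):
    """Spatial size after the matching number of UpSampling2D((2, 2)) layers."""
    h, w, _ = encoded_shape(height, width, pools)
    return h * 2 ** pools, w * 2 ** pools


print(encoded_shape(128, 128))  # (8, 8, 8)
print(decoded_size(100, 100))   # (112, 112): 100 is not a multiple of 16
```

If the decoded size differs from the input size, the reconstruction loss cannot be computed against the original images, so it is worth checking this before training.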

After training finishes, extracting image encodings with the encoder model raises the following error:

~/.local/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py in predict(self, x, batch_size, verbose, steps, max_queue_size, workers, use_multiprocessing)
   1838           max_queue_size=max_queue_size,
   1839           workers=workers,
-> 1840           use_multiprocessing=use_multiprocessing)
   1841
   1842     # Backwards compatibility.

~/.local/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py in predict_generator(self, generator, steps, max_queue_size, workers, use_multiprocessing, verbose)
   2296         workers=workers,
   2297         use_multiprocessing=use_multiprocessing,
-> 2298         verbose=verbose)
   2299
   2300   def _get_callback_model(self):

~/.local/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_generator.py in predict_generator(model, generator, steps, max_queue_size, workers, use_multiprocessing, verbose)
    352   """See docstring for `Model.predict_generator`."""
    353   if not context.executing_eagerly():
--> 354     model._make_test_function()
    355
    356   steps_done = 0

~/.local/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py in _make_test_function(self)
    714   def _make_test_function(self):
    715     if not hasattr(self, 'test_function'):
--> 716       raise RuntimeError('You must compile your model before using it.')
    717     if self.test_function is None:
    718       inputs = (self._feed_inputs +

RuntimeError: You must compile your model before using it.

What is wrong here? Any help would be much appreciated.

1 answer:

Answer 0 (score: 0)

This appears to be a TensorFlow version issue. I got this error with TensorFlow 2.2; after downgrading to 1.12, the error went away.
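A workaround not mentioned in the answer, sketched here under assumptions, is to compile the encoder as well before calling predict. In the traceback above, `predict_generator` calls `_make_test_function`, which raises this RuntimeError when the model has no `test_function` attribute, i.e. when it was never compiled. Because `encoder` shares its layers with `autoencoder`, compiling it does not reset the trained weights. The toy model, the 32x32x3 input size, and the `batches` generator below are hypothetical stand-ins for the question's model and `train_generator`.

```python
import numpy as np
from tensorflow.keras.layers import Conv2D, Input, MaxPooling2D, UpSampling2D
from tensorflow.keras.models import Model

img_size, img_channel = 32, 3  # hypothetical input size

# Toy encoder/decoder with the same structure as the question's model
input_img = Input(shape=(img_size, img_size, img_channel))
x = Conv2D(8, (3, 3), activation='relu', padding='same')(input_img)
encoded = MaxPooling2D((2, 2), padding='same')(x)
x = UpSampling2D((2, 2))(encoded)
decoded = Conv2D(img_channel, (3, 3), activation='sigmoid', padding='same')(x)

autoencoder = Model(input_img, decoded)
encoder = Model(input_img, encoded)  # shares layers (and weights) with autoencoder

autoencoder.compile(optimizer='sgd', loss='mean_squared_error')
x_train = np.random.rand(16, img_size, img_size, img_channel).astype('float32')
autoencoder.fit(x_train, x_train, epochs=1, batch_size=8, verbose=0)

# Workaround: compile the encoder too. The optimizer and loss are not used
# by predict; compiling only sets up what the predict path checks for.
encoder.compile(optimizer='sgd', loss='mean_squared_error')


def batches(data, batch_size=8):
    # Hypothetical generator that yields input batches only
    for i in range(0, len(data), batch_size):
        yield data[i:i + batch_size]


train_encoding = encoder.predict(batches(x_train), steps=2, verbose=0)
print(train_encoding.shape)  # (16, 16, 16, 8)
```

Any valid optimizer and loss can be passed to `encoder.compile`, since they only matter for training.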
