Problem applying model.predict() inside a data generator used with fit_generator()

Time: 2019-05-16 23:28:10

Tags: python tensorflow keras deep-learning

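Basically, I am implementing a model that uses perceptual loss to perform single-image super-resolution. I built the full model so that the input first passes through the main model, is then fed into a pretrained VGG16, and the output of VGG16's layers[5] is the full model's final output. I am trying to pass the pretrained VGG16 model to my data generator so that it can prepare the ground-truth targets for computing the perceptual loss on the fly. However, during training with fit_generator I get a ValueError.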

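I tried writing my own loop that generates the data for each batch and calls train_on_batch instead, and that works fine. However, I would really like to use fit_generator with use_multiprocessing.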

Here is the generator I wrote. I pass lossModel to the generator and use it to generate the training targets for the perceptual loss.

import numpy as np
import keras

class DataGenerator(keras.utils.Sequence):
    'Generates data for Keras'
    def __init__(self, x_train, y_train, lossModel, batch_size=4, shuffle=True):
        'Initialization'
        self.x_train = x_train
        self.y_train = y_train
        self.lossModel = lossModel
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.on_epoch_end()

    def __len__(self):
        'Denotes the number of batches per epoch'
        return int(np.floor(len(self.x_train) / self.batch_size))

    def __getitem__(self, index):
        'Generate one batch of data'
        # Generate  batch of data
        idx = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
        x = self.x_train[idx,]
        y = self.lossModel.predict_on_batch(self.y_train[idx,])
        return x, y

    def on_epoch_end(self):
        'Updates indexes after each epoch'
        self.indexes = np.arange(len(self.x_train))
        if self.shuffle:
            np.random.shuffle(self.indexes)
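The batching logic of DataGenerator does not depend on Keras. In the NumPy-only sketch below, the hypothetical IndexBatcher class and its target_fn parameter stand in for DataGenerator and lossModel.predict_on_batch (the seed parameter is also an addition, for reproducibility). It makes two consequences visible: __len__ drops any trailing partial batch, and the target model is called once per batch inside whichever worker calls __getitem__.

```python
import numpy as np


class IndexBatcher:
    """Hypothetical Keras-free stand-in for DataGenerator: same indexing,
    but target_fn replaces lossModel.predict_on_batch."""

    def __init__(self, x, y, target_fn, batch_size=4, shuffle=True, seed=0):
        self.x, self.y = x, y
        self.target_fn = target_fn
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.rng = np.random.default_rng(seed)
        self.on_epoch_end()

    def __len__(self):
        # Same as int(np.floor(n / batch_size)): a trailing partial batch is dropped.
        return len(self.x) // self.batch_size

    def __getitem__(self, index):
        idx = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]
        # The target is computed per batch, in whichever worker calls this method.
        return self.x[idx], self.target_fn(self.y[idx])

    def on_epoch_end(self):
        self.indexes = np.arange(len(self.x))
        if self.shuffle:
            self.rng.shuffle(self.indexes)


# Toy data: 10 samples with batch size 4 gives 2 batches; 2 samples go unused.
x = np.arange(10).reshape(10, 1)
y = np.arange(10).reshape(10, 1) * 100
gen = IndexBatcher(x, y, target_fn=lambda t: t + 1)
print(len(gen))             # 2
xb, yb = gen[0]
print(xb.shape, yb.shape)   # (4, 1) (4, 1)
```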

Here is where I build the model.

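from keras.applications.vgg16 import VGG16
from keras.models import Model
# ResnetBuilder and basic_block come from a ResNet builder module not shown in the question

### Create Image Transformation Model ###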
mainModel = ResnetBuilder.build((3,72,72), 5, basic_block, [1, 1, 1, 1, 1])

### Create Loss Model (VGG16) ###
lossModel = VGG16(include_top=False, weights='imagenet', input_tensor=None, input_shape=(288,288,3))
lossModel.trainable=False
for layer in lossModel.layers:
    layer.trainable=False

### Create New Loss Model (Use Relu2-2 layer output for perceptual loss)
lossModel = Model(lossModel.inputs,lossModel.layers[5].output)
lossModelOutputs = lossModel(mainModel.output)

### Create Full Model ###
fullModel = Model(mainModel.input, lossModelOutputs)

### Compile Full Model ###
fullModel.compile(loss='mse', optimizer='adam',metrics=['mse'])
trained_epochs=0
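The output shape of lossModel.layers[5] (block2_conv2, as the traceback confirms) for a 288×288 input can be checked by hand. The sketch below hard-codes the first layers of keras.applications.VGG16(include_top=False), assuming 'same'-padded convolutions and 2×2, stride-2 max-pooling; the input layer's auto-generated name varies, so a placeholder is used. VGG16_PREFIX and output_shape_at are hypothetical helpers, not Keras API.

```python
# Layer order of VGG16(include_top=False) up to block 2: (name, filters or None).
# "input" is a placeholder; Keras auto-generates the real input layer name.
VGG16_PREFIX = [
    ("input", None),
    ("block1_conv1", 64),
    ("block1_conv2", 64),
    ("block1_pool", None),
    ("block2_conv1", 128),
    ("block2_conv2", 128),
    ("block2_pool", None),
]


def output_shape_at(layer_index, height, width, channels=3):
    """Trace (H, W, C) through the prefix: 'same' convs keep H and W and set C
    to their filter count; each 2x2/stride-2 max-pool halves H and W."""
    h, w, c = height, width, channels
    for name, filters in VGG16_PREFIX[1:layer_index + 1]:
        if name.endswith("pool"):
            h, w = h // 2, w // 2
        else:
            c = filters
    return VGG16_PREFIX[layer_index][0], (h, w, c)


print(output_shape_at(5, 288, 288))  # ('block2_conv2', (144, 144, 128))
```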

The error occurs during fit_generator(). Note that my input has shape (72,72,3), the output of VGG's layers[5] has shape (144,144,128), and my y_train holds ground-truth images of shape (288,288,3).
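Because fullModel is compiled with loss='mse' and its targets come from the same lossModel, training minimizes a feature-reconstruction (perceptual) loss on the block2_conv2 feature map. Writing f for mainModel and φ for lossModel (notation introduced here, not in the code), the per-sample loss with the shapes above is:

```latex
\mathcal{L}(x, y) = \frac{1}{144 \cdot 144 \cdot 128}
  \sum_{i=1}^{144} \sum_{j=1}^{144} \sum_{k=1}^{128}
  \bigl( \phi(f(x))_{ijk} - \phi(y)_{ijk} \bigr)^2,
\qquad x \in \mathbb{R}^{72 \times 72 \times 3},\;
f(x),\, y \in \mathbb{R}^{288 \times 288 \times 3},\;
\phi(\cdot) \in \mathbb{R}^{144 \times 144 \times 128}
```

Keras then averages this over the batch. The generator's only job is to supply φ(y) for each batch, which is why it calls lossModel.predict_on_batch.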

# Generators
training_generator = DataGenerator(x_train, y_train, lossModel, batch_size=4, shuffle=True)
# Train model on dataset
fullModel.fit_generator(generator=training_generator, use_multiprocessing=True, workers=6)
Epoch 1/1
---------------------------------------------------------------------------
RemoteTraceback                           Traceback (most recent call last)
RemoteTraceback: 
"""
Traceback (most recent call last):
  File "/home/lucien/anaconda3/envs/fyp/lib/python3.6/multiprocessing/pool.py", line 119, in worker
    result = (True, func(*args, **kwds))
  File "/home/lucien/anaconda3/envs/fyp/lib/python3.6/site-packages/keras/utils/data_utils.py", line 401, in get_index
    return _SHARED_SEQUENCES[uid][i]
  File "/home/lucien/sr-perceptual/my_classes.py", line 26, in __getitem__
    y = self.lossModel.predict_on_batch(self.y_train[idx,])
  File "/home/lucien/anaconda3/envs/fyp/lib/python3.6/site-packages/keras/engine/training.py", line 1273, in predict_on_batch
    self._make_predict_function()
  File "/home/lucien/anaconda3/envs/fyp/lib/python3.6/site-packages/keras/engine/training.py", line 554, in _make_predict_function
    **kwargs)
  File "/home/lucien/anaconda3/envs/fyp/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py", line 2744, in function
    return Function(inputs, outputs, updates=updates, **kwargs)
  File "/home/lucien/anaconda3/envs/fyp/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py", line 2546, in __init__
    with tf.control_dependencies(self.outputs):
  File "/home/lucien/anaconda3/envs/fyp/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 5004, in control_dependencies
    return get_default_graph().control_dependencies(control_inputs)
  File "/home/lucien/anaconda3/envs/fyp/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 4543, in control_dependencies
    c = self.as_graph_element(c)
  File "/home/lucien/anaconda3/envs/fyp/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3490, in as_graph_element
    return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
  File "/home/lucien/anaconda3/envs/fyp/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3569, in _as_graph_element_locked
    raise ValueError("Tensor %s is not an element of this graph." % obj)
ValueError: Tensor Tensor("block2_conv2/Relu:0", shape=(?, 144, 144, 128), dtype=float32) is not an element of this graph.
"""

The above exception was the direct cause of the following exception:

ValueError                                Traceback (most recent call last)
<ipython-input-10-4a040e0935cf> in <module>
      1 # Train model on dataset
----> 2 fullModel.fit_generator(generator=training_generator, use_multiprocessing=True, workers=6)

~/anaconda3/envs/fyp/lib/python3.6/site-packages/keras/legacy/interfaces.py in wrapper(*args, **kwargs)
     89                 warnings.warn('Update your `' + object_name + '` call to the ' +
     90                               'Keras 2 API: ' + signature, stacklevel=2)
---> 91             return func(*args, **kwargs)
     92         wrapper._original_function = func
     93         return wrapper

~/anaconda3/envs/fyp/lib/python3.6/site-packages/keras/engine/training.py in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
   1416             use_multiprocessing=use_multiprocessing,
   1417             shuffle=shuffle,
-> 1418             initial_epoch=initial_epoch)
   1419 
   1420     @interfaces.legacy_generator_methods_support

~/anaconda3/envs/fyp/lib/python3.6/site-packages/keras/engine/training_generator.py in fit_generator(model, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
    179             batch_index = 0
    180             while steps_done < steps_per_epoch:
--> 181                 generator_output = next(output_generator)
    182 
    183                 if not hasattr(generator_output, '__len__'):

~/anaconda3/envs/fyp/lib/python3.6/site-packages/keras/utils/data_utils.py in get(self)
    599         except Exception as e:
    600             self.stop()
--> 601             six.reraise(*sys.exc_info())
    602 
    603 

~/anaconda3/envs/fyp/lib/python3.6/site-packages/six.py in reraise(tp, value, tb)
    691             if value.__traceback__ is not tb:
    692                 raise value.with_traceback(tb)
--> 693             raise value
    694         finally:
    695             value = None

~/anaconda3/envs/fyp/lib/python3.6/site-packages/keras/utils/data_utils.py in get(self)
    593         try:
    594             while self.is_running():
--> 595                 inputs = self.queue.get(block=True).get()
    596                 self.queue.task_done()
    597                 if inputs is not None:

~/anaconda3/envs/fyp/lib/python3.6/multiprocessing/pool.py in get(self, timeout)
    642             return self._value
    643         else:
--> 644             raise self._value
    645 
    646     def _set(self, i, obj):

ValueError: Tensor Tensor("block2_conv2/Relu:0", shape=(?, 144, 144, 128), dtype=float32) is not an element of this graph.

1 Answer:

Answer 0 (score: 0)

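The problem here is the parallel workers. With 6 workers, the predict function that references block2_conv2/Relu:0 is built lazily inside a worker, and there the default TensorFlow graph is not the graph in which the model was created, hence "is not an element of this graph".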

The problem is in _make_predict_function(). You can see this in the file on your machine (I took it from your error output): File "/home/lucien/anaconda3/envs/fyp/lib/python3.6/site-packages/keras/engine/training.py", line 1273, in predict_on_batch, which calls self._make_predict_function().

Some ways to get rid of the error:

  • Use the Theano backend.
  • Call model._make_predict_function() right after loading the trained model.
  • Use a global model:

The loading function:

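import tensorflow as tf  # TF 1.x API (tf.get_default_graph)

def load_model():
    global model
    model = yourmodel(weights=xx111122)  # placeholder: your own model constructor and weights
    # this is key: save the graph right after loading the model
    global graph
    graph = tf.get_default_graph()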

Then, when predicting:

with graph.as_default():
    preds = model.predict(image)
    # ... etc.
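The failure and the two code-level fixes above can be mimicked without TensorFlow. The sketch below is a toy analogy, not TensorFlow's implementation: Graph, as_default, ToyModel and run_in_worker are hypothetical names. It uses threads, whereas the traceback shows Keras running the Sequence in a multiprocessing pool, and it models only two behaviors: the "default graph" is resolved per thread, and the predict function is built lazily on first use. The main thread is deliberately put on its own graph so that the worker's default differs.

```python
import threading


class Graph:
    """Toy stand-in for a TF1 graph: it only records which tensors belong to it."""
    def __init__(self, name):
        self.name = name
        self.tensors = set()


_local = threading.local()        # per-thread default graph (toy analogue)
_global_graph = Graph("global")


def get_default_graph():
    return getattr(_local, "graph", _global_graph)


class as_default:
    """Make `graph` the default in the current thread only (toy graph.as_default())."""
    def __init__(self, graph):
        self.graph = graph

    def __enter__(self):
        self.prev = getattr(_local, "graph", None)
        _local.graph = self.graph
        return self.graph

    def __exit__(self, *exc):
        if self.prev is None:
            del _local.graph
        else:
            _local.graph = self.prev
        return False


class ToyModel:
    """Hypothetical model whose predict function is built lazily."""
    def __init__(self):
        self.output = "block2_conv2/Relu:0"
        get_default_graph().tensors.add(self.output)
        self.predict_function = None

    def make_predict_function(self):
        if self.predict_function is None:
            graph = get_default_graph()   # resolved in the *calling* thread
            if self.output not in graph.tensors:
                raise ValueError(f"Tensor {self.output} is not an element of this graph.")
            self.predict_function = lambda batch: [v * 2 for v in batch]

    def predict_on_batch(self, batch):
        self.make_predict_function()
        return self.predict_function(batch)


def run_in_worker(fn):
    """Run fn in a new thread; return ('ok', result) or ('error', message)."""
    out = {}

    def target():
        try:
            out["r"] = ("ok", fn())
        except ValueError as e:
            out["r"] = ("error", str(e))

    t = threading.Thread(target=target)
    t.start()
    t.join()
    return out["r"]


main_graph = Graph("main")

# Failure: model built under main_graph, predict function first built in a worker.
with as_default(main_graph):
    lazy = ToyModel()
print(run_in_worker(lambda: lazy.predict_on_batch([1, 2])))
# ('error', 'Tensor block2_conv2/Relu:0 is not an element of this graph.')

# Fix A: build the predict function eagerly, in the thread that owns the graph.
with as_default(main_graph):
    eager = ToyModel()
    eager.make_predict_function()
print(run_in_worker(lambda: eager.predict_on_batch([1, 2])))   # ('ok', [2, 4])

# Fix B: keep a handle on the graph and enter it inside the worker.
with as_default(main_graph):
    model = ToyModel()
    graph = get_default_graph()


def predict_with_graph():
    with as_default(graph):
        return model.predict_on_batch([1, 2])


print(run_in_worker(predict_with_graph))   # ('ok', [2, 4])
```

In the question's code, the analogue of fix A would be calling lossModel._make_predict_function() before constructing DataGenerator. That method is private to Keras 2.x (it appears in the traceback), and whether it alone is enough with use_multiprocessing=True, where workers are separate processes, is not established by this toy.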