循环训练keras模型:在进一步调用K.clear_session()

时间:2019-04-18 08:47:16

标签: tensorflow keras

我试图循环训练多个Keras模型以评估不同的参数。为避免内存问题,在每次构建模型之前,我都致电K.clear_session()

添加了K.clear_session()调用后,保存第二个模型时我开始收到此错误。

  

引发ValueError(“张量%s不是此图的元素。”%obj)   ValueError:Tensor Tensor(“ level1 / kernel:0”,shape =(3,3,3,16),dtype = float32_ref)不是此图的元素。   在处理上述异常期间,发生了另一个异常:

     

回溯(最近通话最近):     在第286行的“ /home/gus/workspaces/wpy/cnn/srs/train_generators.py”文件中       train_models(model_defs)     在train_models中的文件“ /home/gus/workspaces/wpy/cnn/srs/train_generators.py”中,第196行       model.save(file_path)     保存中的文件“ /home/gus/workspaces/venvs/dlcv/lib/python3.6/site-packages/keras/engine/network.py”,行1090       save_model(自身,文件路径,覆盖,include_optimizer)     在save_model中的文件“ /home/gus/workspaces/venvs/dlcv/lib/python3.6/site-packages/keras/engine/saving.py”,行382       _serialize_model(model,f,include_optimizer)     _serialize_model中的文件“ /home/gus/workspaces/venvs/dlcv/lib/python3.6/site-packages/keras/engine/saving.py”,第97行       weight_values = K.batch_get_value(符号权重)     文件“ /home/gus/workspaces/venvs/dlcv/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py”,第2420行,位于batch_get_value中       返回get_session()。run(ops)     运行中的文件“ /home/gus/workspaces/venvs/dlcv/lib/python3.6/site-packages/tensorflow/python/client/session.py”,第929行       run_metadata_ptr)     _run中的文件“ /home/gus/workspaces/venvs/dlcv/lib/python3.6/site-packages/tensorflow/python/client/session.py”,行1137       self._graph,提取,feed_dict_tensor,feed_handles = feed_handles)      init 中的文件“ /home/gus/workspaces/venvs/dlcv/lib/python3.6/site-packages/tensorflow/python/client/session.py”,第471行       self._fetch_mapper = _FetchMapper.for_fetch(获取)     在for_fetch中的文件261行中的文件“ /home/gus/workspaces/venvs/dlcv/lib/python3.6/site-packages/tensorflow/python/client/session.py”       返回_ListFetchMapper(fetch)      init 中的文件“ /home/gus/workspaces/venvs/dlcv/lib/python3.6/site-packages/tensorflow/python/client/session.py”,第370行       self._mappers = [_FetchMapper.for_fetch(fetch)用于以抓取方式进行抓取]     在第370行中输入“ /home/gus/workspaces/venvs/dlcv/lib/python3.6/site-packages/tensorflow/python/client/session.py”       self._mappers = [_FetchMapper.for_fetch(fetch)用于以抓取方式进行抓取]     在for_fetch中,文件“ /home/gus/workspaces/venvs/dlcv/lib/python3.6/site-packages/tensorflow/python/client/session.py”,第271行       返回_ElementFetchMapper(fetches,contraction_fn)      init 中的第307行中的文件“ /home/gus/workspaces/venvs/dlcv/lib/python3.6/site-packages/tensorflow/python/client/session.py”       '张量。 (%s)'%(获取,str(e)))    ValueError:提取参数不能解释为张量。 (Tensor Tensor(“ level1 / kernel:0”,shape =(3,3,3,16),dtype = float32_ref)不是此图的元素。)

代码基本上是:

while <models to train>:
    K.clear_session()
    model = modeldef.build() # everything that has a tensor goes here and just here
    # create generators from directories

    opt = Adam(lr=0.001, decay=0.001 / epochs)
    model.compile(...)
    H = model.fit_generator(...)

    model.save(file_path) # --> here it crashes

无论网络有多深,如此简单的ConvNet都会使代码在保存时失败:

class SuperSimpleCNN:
    def __init__(self, img_size, depth):
        self.img_size = img_size
        self.depth = depth

    def build(self):
        init = Input(shape=(self.img_size, self.img_size, self.depth))

        x = Convolution2D(16, (3, 3), padding='same', name='level1')(init)
        x = Activation('relu')(x)

        out = Convolution2D(self.depth, (5, 5), padding='same', name='output')(x)
        model = Model(init, out)
        return model

看着类似的问题,我知道这个问题是由于keras共享一个全局会话,并且来自不同模型的不同图无法混合。 但是我不明白为什么在每个模型中使用K.clear_session()会使迭代> 1时保存操作失败。以及为什么Tensor和Variable之间的区别。

  

<< strong> tf.Variable 'level1 / kernel:0'shape =(3,3,3,16)dtype = float32_ref>不能解释为 Tensor < / p>

有人可以帮忙吗?

谢谢。

1 个答案:

答案 0 :(得分:0)

我的错误,我导入了错误的软件包:

  

从tensorflow.python.keras导入后端为K

代替

  

将keras.backend导入为K