我试图循环训练多个Keras模型以评估不同的参数。为避免内存问题,在每次构建模型之前,我都致电K.clear_session()
。
添加了K.clear_session()
调用后,保存第二个模型时我开始收到此错误。
引发ValueError(“张量%s不是此图的元素。”%obj) ValueError:Tensor Tensor(“ level1 / kernel:0”,shape =(3,3,3,16),dtype = float32_ref)不是此图的元素。 在处理上述异常期间,发生了另一个异常:
回溯(最近通话最近): 在第286行的“ /home/gus/workspaces/wpy/cnn/srs/train_generators.py”文件中 train_models(model_defs) 在train_models中的文件“ /home/gus/workspaces/wpy/cnn/srs/train_generators.py”中,第196行 model.save(file_path) 保存中的文件“ /home/gus/workspaces/venvs/dlcv/lib/python3.6/site-packages/keras/engine/network.py”,行1090 save_model(自身,文件路径,覆盖,include_optimizer) 在save_model中的文件“ /home/gus/workspaces/venvs/dlcv/lib/python3.6/site-packages/keras/engine/saving.py”,行382 _serialize_model(model,f,include_optimizer) _serialize_model中的文件“ /home/gus/workspaces/venvs/dlcv/lib/python3.6/site-packages/keras/engine/saving.py”,第97行 weight_values = K.batch_get_value(符号权重) 文件“ /home/gus/workspaces/venvs/dlcv/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py”,第2420行,位于batch_get_value中 返回get_session()。run(ops) 运行中的文件“ /home/gus/workspaces/venvs/dlcv/lib/python3.6/site-packages/tensorflow/python/client/session.py”,第929行 run_metadata_ptr) _run中的文件“ /home/gus/workspaces/venvs/dlcv/lib/python3.6/site-packages/tensorflow/python/client/session.py”,行1137 self._graph,提取,feed_dict_tensor,feed_handles = feed_handles) init 中的文件“ /home/gus/workspaces/venvs/dlcv/lib/python3.6/site-packages/tensorflow/python/client/session.py”,第471行 self._fetch_mapper = _FetchMapper.for_fetch(获取) 在for_fetch中的文件261行中的文件“ /home/gus/workspaces/venvs/dlcv/lib/python3.6/site-packages/tensorflow/python/client/session.py” 返回_ListFetchMapper(fetch) init 中的文件“ /home/gus/workspaces/venvs/dlcv/lib/python3.6/site-packages/tensorflow/python/client/session.py”,第370行 self._mappers = [_FetchMapper.for_fetch(fetch)用于以抓取方式进行抓取] 在第370行中输入“ /home/gus/workspaces/venvs/dlcv/lib/python3.6/site-packages/tensorflow/python/client/session.py” self._mappers = [_FetchMapper.for_fetch(fetch)用于以抓取方式进行抓取] 在for_fetch中,文件“ /home/gus/workspaces/venvs/dlcv/lib/python3.6/site-packages/tensorflow/python/client/session.py”,第271行 返回_ElementFetchMapper(fetches,contraction_fn) init 中的第307行中的文件“ /home/gus/workspaces/venvs/dlcv/lib/python3.6/site-packages/tensorflow/python/client/session.py” '张量。 (%s)'%(获取,str(e))) ValueError:提取参数不能解释为张量。 (Tensor Tensor(“ level1 / kernel:0”,shape =(3,3,3,16),dtype = float32_ref)不是此图的元素。)
代码基本上是:
while <models to train>:
K.clear_session()
model = modeldef.build() # everything that has a tensor goes here and just here
# create generators from directories
opt = Adam(lr=0.001, decay=0.001 / epochs)
model.compile(...)
H = model.fit_generator(...)
model.save(file_path) # --> here it crashes
无论网络有多深,如此简单的ConvNet都会使代码在保存时失败:
class SuperSimpleCNN:
def __init__(self, img_size, depth):
self.img_size = img_size
self.depth = depth
def build(self):
init = Input(shape=(self.img_size, self.img_size, self.depth))
x = Convolution2D(16, (3, 3), padding='same', name='level1')(init)
x = Activation('relu')(x)
out = Convolution2D(self.depth, (5, 5), padding='same', name='output')(x)
model = Model(init, out)
return model
看着类似的问题,我知道这个问题是由于keras共享一个全局会话,并且来自不同模型的不同图无法混合。
但是我不明白为什么在每个模型中使用K.clear_session()
会使迭代> 1时保存操作失败。以及为什么Tensor和Variable之间的区别。
<< strong> tf.Variable 'level1 / kernel:0'shape =(3,3,3,16)dtype = float32_ref>不能解释为 Tensor < / p>
有人可以帮忙吗?
谢谢。
答案 0 :(得分:0)
我的错误,我导入了错误的软件包:
从tensorflow.python.keras导入后端为K
代替
将keras.backend导入为K