更新:我发现使用tensorflow-cpu时下面的代码能正常工作。使用tensorflow-gpu时问题仍然存在。我怎样才能使它发挥作用?
我在代码中找不到问题 - 我正在尝试保存变量,然后重新加载它们,而且它们似乎不会从保存的模型中加载。
我会注意到,如果我在同一个python运行中进行保存和加载(没有进程结束并运行测试脚本),它们会加载。我的问题是,当我训练模式时,这不起作用 - >保存 - >流程结束 - >再次使用测试标志运行脚本 - >模型加载时没有错误,但结果就好像它不是。
代码:
运行#1
# creating LSTM model...
with tf.Session() as sess:
saver = tf.train.Saver()
# training...
save_path = saver.save(sess, "./saved_models/model.ckpt")
print("Model saved in file: %s" % save_path)
运行#2
# creating the same exact LSTM model...
with tf.Session() as sess:
saver = tf.train.Saver()
saver.restore(sess, "./saved_models/model.ckpt")
print("Model restored.")
# testing...
如果我背靠背地运行这两个片段,我会得到所需的输出 - 训练模型来预测一个简单的序列,并在测试期间正确预测它。如果我单独运行两个片段,模型会在测试期间预测错误的序列。
更新:我被建议尝试导入MetaGraph,但它也没有用。代码:
运行#1
# creating model...
tf.add_to_collection('a', net.a)
# adding nodes ...
tf.add_to_collection('z', net.z)
with tf.Session() as sess:
saver = tf.train.Saver()
# training...
save_path = saver.save(sess, "./saved_models/my-model")
print("Model saved in file: %s" % save_path)
运行#2
with tf.Session() as sess:
new_saver = tf.train.import_meta_graph('./saved_models/my-model.meta')
new_saver.restore(sess, './saved_models/my-model')
net.a = tf.get_collection('a')[0]
# adding nodes ...
net.z = tf.get_collection('z')[0]
# testing...
上面的代码运行正常 - 但是测试集结果显示它不是后期训练(再次,如果我在同一个Python实例中运行两个片段,它可以正常工作)。
这应该是相当微不足道的,我无法让它发挥作用。欢迎任何帮助。具体来说,我并不是真的需要保存整个图形 - 只是变量(其中一些在LSTM单元格内)。
答案 0 :(得分:1)
我遇到了同样的问题,我想你使用tf.Variable()
,对吧?
尝试将其更改为tf.get_variable()
。它对我有用:)