在Tensorflow中保存模型在GPU下不起作用?

时间:2017-08-20 14:15:50

标签: python tensorflow

更新:我发现使用tensorflow-cpu时下面的代码能正常工作。使用tensorflow-gpu时问题仍然存在。我怎样才能使它发挥作用?

我在代码中找不到问题 - 我正在尝试保存变量,然后重新加载它们,而且它们似乎不会从保存的模型中加载。

我会注意到,如果我在同一个python运行中进行保存和加载(没有进程结束并运行测试脚本),它们会加载。我的问题是,当我训练模式时,这不起作用 - >保存 - >流程结束 - >再次使用测试标志运行脚本 - >模型加载时没有错误,但结果就好像它不是。

代码:

运行#1

# creating LSTM model...

with tf.Session() as sess:
    saver = tf.train.Saver()

    # training...

    save_path = saver.save(sess, "./saved_models/model.ckpt")
    print("Model saved in file: %s" % save_path)

运行#2

# creating the same exact LSTM model...

with tf.Session() as sess:
    saver = tf.train.Saver()

    saver.restore(sess, "./saved_models/model.ckpt")
    print("Model restored.")

    # testing...

如果我背靠背地运行这两个片段,我会得到所需的输出 - 训练模型来预测一个简单的序列,并在测试期间正确预测它。如果我单独运行两个片段,模型会在测试期间预测错误的序列。

更新:我被建议尝试导入MetaGraph,但它也没有用。代码:

运行#1

# creating model...

tf.add_to_collection('a', net.a)
# adding nodes ...
tf.add_to_collection('z', net.z)

with tf.Session() as sess:
    saver = tf.train.Saver()
    # training...
    save_path = saver.save(sess, "./saved_models/my-model")
    print("Model saved in file: %s" % save_path)

运行#2

with tf.Session() as sess:
    new_saver = tf.train.import_meta_graph('./saved_models/my-model.meta')
    new_saver.restore(sess, './saved_models/my-model')

    net.a = tf.get_collection('a')[0]
    # adding nodes ...
    net.z = tf.get_collection('z')[0]

    # testing...

上面的代码运行正常 - 但是测试集结果显示它不是后期训练(再次,如果我在同一个Python实例中运行两个片段,它可以正常工作)。

这应该是相当微不足道的,我无法让它发挥作用。欢迎任何帮助。具体来说,我并不是真的需要保存整个图形 - 只是变量(其中一些在LSTM单元格内)。

1 个答案:

答案 0 :(得分:1)

我遇到了同样的问题,我想你使用tf.Variable(),对吧? 尝试将其更改为tf.get_variable()。它对我有用:)