如何从保存的元图恢复我的损失?

时间:2017-06-12 23:15:36

标签: tensorflow scope restore

我已经构建了一个工作正常的简单张量流模型。 训练时我会在不同的步骤中保存meta_graph和一些参数。

之后(在新脚本中)我想恢复保存的meta_graph并恢复变量和操作。

一切正常,但只有

with tf.name_scope('MSE'):
    error = tf.losses.mean_squared_error(Y, yhat, scope="error")

不会被恢复。使用以下行

mse_error = graph.get_tensor_by_name("MSE/error:0")
  

"名称' MSE /错误:0'是指不存在的张量。该   操作,' MSE /错误',图中不存在。"

出现此错误消息。

由于我对其他变量和操作执行完全相同的过程而没有任何错误,我不知道如何处理。唯一的区别是tf.losses.mean_squared_error函数中只有一个scope属性而不是name属性。

那么如何使用范围恢复损失操作?

这里是我保存和加载模型的代码。

保存:

# define network ...
saver = tf.train.Saver(max_to_keep=10)

sess = tf.Session()
sess.run(tf.global_variables_initializer())

for i in range(NUM_EPOCHS):
    # do training ..., save model all 1000 optimization steps
    if (i + 1) % 1000 == 0:
        saver.save(sess, "L:/model/mlp_model", global_step=(i+1))

还原:

# start a session
sess=tf.Session()
# load meta graph
saver = tf.train.import_meta_graph('L:\\model\\mlp_model-1000.meta')
# restore weights
saver.restore(sess, tf.train.latest_checkpoint('L:\\model\\'))

# access network nodes
graph = tf.get_default_graph()
X = graph.get_tensor_by_name("Input/X:0")
Y = graph.get_tensor_by_name("Input/Y:0")

# restore output-generating operation used for prediction
yhat_op = graph.get_tensor_by_name("OutputLayer/yhat:0")
mse_error = graph.get_tensor_by_name("MSE/error:0") # this one doesn't work

2 个答案:

答案 0 :(得分:4)

为了让您的训练退一步,documentation建议您将其添加到集合中,然后再将其保存为在恢复图表后能够指向它的方式。

保存:

saver = tf.train.Saver(max_to_keep=10)
# put op in collection
tf.add_to_collection('train_op', train_op)
...

还原:

saver = tf.train.import_meta_graph('L:\\model\\mlp_model-1000.meta')
saver.restore(sess, tf.train.latest_checkpoint('L:\\model\\'))
# recover op through collection
train_op = tf.get_collection('train_op')[0]

为什么您尝试按名称恢复张量失败?

你的名字确实可以得到张量 - 问题是你需要正确的名字。请注意,error tf.losses.mean_squared_error的{​​{1}}参数是范围名称,而不是返回操作的名称。这可能令人困惑,因为其他操作(例如tf.nn.l2_loss)接受name参数。

最后,error操作的名称为MSE/error/value:0,您可以使用该名称来获取名称。

也就是说,直到你更新tensorflow时再次破坏。 tf.losses.mean_squared_error并未对其输出名称提供任何保证,因此很可能因某些原因而改变。

我认为这是推动收藏品使用的动力:缺乏对自己无法控制的经营者名称的保证。

或者,如果由于某种原因你真的想使用名字,你可以像这样重命名你的运算符:

with tf.name_scope('MSE'):
  error = tf.losses.mean_squared_error(Y, yhat, scope='error')
  # let me stick my own name on it
  error = tf.identity(error, 'my_error')

然后你可以安全地依赖graph.get_tensor_by_name('MSE/my_error:0')

答案 1 :(得分:1)

tf.losses.mean_squared_error是不是张量的操作,您应该使用 get_operation_by_name

mse_error = graph.get_operation_by_name("MSE/error")

应该有效,请注意,不需要“:0”