如何在Tensorflow中设置损失操作的名称?

时间:2017-09-16 20:27:16

标签: python tensorflow

在Tensorflow中,我可以为操作和张量指定名称,以便以后检索它们。例如,我可以做一个功能

input_layer=tf.placeholder(tf.float32, shape= [None,300], name='input_layer')

然后在另一个函数中,我可以做

input_layer=get_tensor_by_name('input_layer:0')

我开始相信这对于使我的tf代码尽可能模块化非常方便。

我希望能够对我的损失做同样的事情但是如何为该操作分配自定义名称?问题是损失函数的构建(例如,tf.losses.mean_squared_error)没有名称的参数(与tf.placeholder,tf.variable等相反)。

我现在谈到我的损失的方式是

tf.get_collection(tf.GraphKeys.LOSSES)[-1]

(检索已添加到图表中的最后一次丢失操作)。我错过了一些明显的东西吗?

3 个答案:

答案 0 :(得分:4)

我知道这不是答案,但它是一个可能适合你的修复。

鉴于此,正如您所指出的,tf.losses.mean_squared_error函数没有name参数,您可以实现自己的MSE(当然,基于TF操作)

只需替换

tf_loss = tf.losses.mean_squared_error(labels,predictions)

使用

custom_loss = tf.reduce_mean(tf.squared_difference(labels,predictions),name='loss')

由于reduce_mean接受name参数,您可以获得所需内容。

完整的示例代码here

答案 1 :(得分:3)

我想使用不可训练的变量应该可以做到这一点:

labels = np.random.normal(size=10)
predictions = np.random.normal(size=10)

sess = tf.Session()
loss_var = tf.Variable(10.0, name='mse_loss', trainable=False, dtype=tf.float32)

loss = tf.losses.mean_squared_error(labels, predictions)
mse_loss = loss_var.assign(loss)

sess.run(mse_loss)
print(sess.run(tf.get_default_graph().get_tensor_by_name('mse_loss:0')))

答案 2 :(得分:1)

我发现最简单的方法是使用tf.identity

loss = tf.losses.mean_squared_error(labels, predictions)
loss = tf.identity(loss, name = "loss")