Printing the loss during TensorFlow training

Date: 2015-11-20 18:41:16

Tags: python tensorflow

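I am working through the TensorFlow "MNIST For ML Beginners" tutorial, and I want to print out the training loss after every training step.

My training loop currently looks like this: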

for i in range(100):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

Here, train_step is defined as:

train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

and cross_entropy is the loss that I want to print out:

cross_entropy = -tf.reduce_sum(y_ * tf.log(y))
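
For reference, here is a minimal pure-Python sketch of what that expression computes, assuming y_ holds one-hot labels and y holds predicted class probabilities. The helper name and the sample values are made up for illustration and are not part of the tutorial:

```python
import math


def cross_entropy(y_true, y_pred):
    """Summed cross-entropy: -sum(y_true * log(y_pred)) over all entries."""
    return -sum(t * math.log(p)
                for row_t, row_p in zip(y_true, y_pred)
                for t, p in zip(row_t, row_p))


# One example with three classes; the true class is index 1.
labels = [[0.0, 1.0, 0.0]]
probs = [[0.2, 0.7, 0.1]]

# Only the true class contributes, so this equals -log(0.7).
print(cross_entropy(labels, probs))
```

Because the labels are one-hot, each example contributes only the negative log of the probability assigned to its true class, and the batch loss is the sum of those terms.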

One way to print this would be to compute cross_entropy explicitly inside the training loop:

for i in range(100):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    cross_entropy = -tf.reduce_sum(y_ * tf.log(y))
    print 'loss = ' + str(cross_entropy)
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

I now have two questions about this:

  1. Given that cross_entropy is already computed during sess.run(train_step, ...), computing it a second time seems inefficient, since it requires twice as many forward passes over the training data. Is there a way to access the value of cross_entropy that was computed during sess.run(train_step, ...)?

  2. How do I even print a tf.Variable? Using str(cross_entropy) gives me an error...

Thank you!

2 answers:

Answer 0: (score: 47)

You can fetch the value of cross_entropy by adding it to the list of arguments to sess.run(...). For example, your for-loop could be rewritten as follows:

for i in range(100):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    # cross_entropy is the tensor defined once, before the loop
    _, loss_val = sess.run([train_step, cross_entropy],
                           feed_dict={x: batch_xs, y_: batch_ys})
    print('loss = %s' % loss_val)

The same approach can be used to print the current value of a variable. Say that, in addition to the value of cross_entropy, you want to print the value of a tf.Variable called W. You could do the following:

for i in range(100):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    _, loss_val, W_val = sess.run([train_step, cross_entropy, W],
                                  feed_dict={x: batch_xs, y_: batch_ys})
    print('loss = %s' % loss_val)
    print('W = %s' % W_val)
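
To try the fetch-list pattern end to end without downloading MNIST, here is a minimal self-contained sketch. It rests on several assumptions that are not in the original post: TensorFlow 2.x with the tf.compat.v1 graph-mode API, random synthetic data in place of MNIST, and a hypothetical helper named train_and_log.

```python
import numpy as np
import tensorflow.compat.v1 as tf


def train_and_log(num_steps=5, seed=0):
    """Train a tiny softmax regression and return the loss fetched at each step."""
    rng = np.random.RandomState(seed)
    # Synthetic stand-in for MNIST (assumption): 100 examples, 4 features, 3 classes.
    xs = rng.rand(100, 4).astype(np.float32)
    ys = np.eye(3, dtype=np.float32)[rng.randint(0, 3, size=100)]

    graph = tf.Graph()
    with graph.as_default():
        x = tf.placeholder(tf.float32, [None, 4])
        y_ = tf.placeholder(tf.float32, [None, 3])
        W = tf.Variable(tf.zeros([4, 3]))
        b = tf.Variable(tf.zeros([3]))
        y = tf.nn.softmax(tf.matmul(x, W) + b)
        # The loss tensor is built once, outside the training loop.
        cross_entropy = -tf.reduce_sum(y_ * tf.log(y))
        train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
        init = tf.global_variables_initializer()

    losses = []
    with tf.Session(graph=graph) as sess:
        sess.run(init)
        for i in range(num_steps):
            # A single run call both applies the update and returns the loss.
            _, loss_val = sess.run([train_step, cross_entropy],
                                   feed_dict={x: xs, y_: ys})
            losses.append(float(loss_val))
            print('step %d: loss = %s' % (i, loss_val))
    return losses
```

Calling train_and_log() prints one loss per step. With the weights initialised to zero, every class starts with probability 1/3, so the first fetched loss should be close to 100 * ln(3).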

Answer 1: (score: 3)

Instead of running only train_step, also run the cross_entropy node so that its value is returned to you. Remember that:

var_as_a_python_value = sess.run(tensorflow_variable)

will give you what you want, so you can do this:

[_, cross_entropy_py] = sess.run([train_step, cross_entropy],
                                 feed_dict={x: batch_xs, y_: batch_ys})

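This both runs the training step and pulls out the cross-entropy value as it was computed during that iteration. Note that I turned both the argument to sess.run and the return values into lists so that both happen.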