如何在tensorflow

时间:2017-04-07 03:53:39

标签: tensorflow

我正在收集每批张量流中的一组汇总统计信息。

我想收集在测试集上计算的相同汇总统计信息,但测试集太大而无法在一批中处理。

在迭代测试集时,是否有方便的方法来计算相同的汇总统计信息?

2 个答案:

答案 0 :(得分:4)

看起来它最近被添加了。我在contrib(以及后来的主线代码)中发现了这一点,流量度量评估。

https://www.tensorflow.org/api_docs/python/tf/metrics/mean

(根据评论更新链接)

答案 1 :(得分:1)

另一种可能性是在tensorflow之外的测试批次上累积摘要,并在图表中有一个虚拟变量,然后您可以为其分配累积结果。例如:假设您计算了多个批次的验证集上的损失,并希望得到平均值的摘要。您可以通过以下方式实现此目的:

with tf.name_scope('valid_loss'):
    v_loss = tf.Variable(tf.constant(0.0), trainable=False)
    self.v_loss_pl = tf.placeholder(tf.float32, shape=[], name='v_loss_pl')
    self.update_v_loss = tf.assign(v_loss, self.v_loss_pl, name='update_v_loss')

with tf.name_scope('valid_summaries'):
    v_loss_s = tf.summary.scalar('validation_loss', v_loss)
    self.valid_summaries = tf.summary.merge([v_loss_s], name='valid_summaries')

然后在评估时间:

total_loss = 0.0
for batch in all_batches:
    loss, _ = sess.run([get_loss, ...], feed_dict={...})
    total_loss += loss
total_loss /= float(n_batches)

[_, v_summary_str] = sess.run([self.update_v_loss, self.valid_summaries],
                              feed_dict={self.v_loss_pl: total_loss})
writer.add_summary(v_summary_str)

虽然这可以完成工作,但它确实感觉有点黑客。来自您发布的contrib的流媒体指标评估可能会更优雅 - 我实际上从来没有遇到它,所以很好奇检查出来。