我正在收集每批张量流中的一组汇总统计信息。
我想收集在测试集上计算的相同汇总统计信息,但测试集太大而无法在一批中处理。
在迭代测试集时,是否有方便的方法来计算相同的汇总统计信息?
答案 0 :(得分:4)
看起来它最近被添加了。我在contrib(以及后来的主线代码)中发现了这一点,流量度量评估。
https://www.tensorflow.org/api_docs/python/tf/metrics/mean
(根据评论更新链接)
答案 1 :(得分:1)
另一种可能性是在tensorflow之外的测试批次上累积摘要,并在图表中有一个虚拟变量,然后您可以为其分配累积结果。例如:假设您计算了多个批次的验证集上的损失,并希望得到平均值的摘要。您可以通过以下方式实现此目的:
with tf.name_scope('valid_loss'):
v_loss = tf.Variable(tf.constant(0.0), trainable=False)
self.v_loss_pl = tf.placeholder(tf.float32, shape=[], name='v_loss_pl')
self.update_v_loss = tf.assign(v_loss, self.v_loss_pl, name='update_v_loss')
with tf.name_scope('valid_summaries'):
v_loss_s = tf.summary.scalar('validation_loss', v_loss)
self.valid_summaries = tf.summary.merge([v_loss_s], name='valid_summaries')
然后在评估时间:
total_loss = 0.0
for batch in all_batches:
loss, _ = sess.run([get_loss, ...], feed_dict={...})
total_loss += loss
total_loss /= float(n_batches)
[_, v_summary_str] = sess.run([self.update_v_loss, self.valid_summaries],
feed_dict={self.v_loss_pl: total_loss})
writer.add_summary(v_summary_str)
虽然这可以完成工作,但它确实感觉有点黑客。来自您发布的contrib的流媒体指标评估可能会更优雅 - 我实际上从来没有遇到它,所以很好奇检查出来。