我希望在训练期间每100次迭代保存预测并调用到tensorboard,我需要计算训练和测试数据集的指标。这是我使用的代码:
precision, _ = tf.metrics.precision(labels_placeholder, mypredictions,
metrics_collections = ['metrics'], updates_collections = ['update_op'])
tf.summary.scalar('train_precision', precision, collections = ['train_metrics'])
tf.summary.scalar('test_precision', precision, collections = ['test_metrics'])
# the metrics ops have local variables that need to be
# init'd each time.
sess.run(tf.local_variables_initializer())
for batch in range(100):
train_images, train_labels = train_dataset.next_batch()
feed_dict = {images_placeholder: train_images, labels_placeholder: train_labels}
sess.run(tf.get_collection('update_op'), feed_dict = feed_dict)
sess.run(tf.get_collection('metrics'), feed_dict = feed_dict)
summary_str = sess.run(train_metrics_summary, feed_dict=feed_dict)
summary_writer.add_summary(summary_str, global_step_val)
sess.run(tf.local_variables_initializer())
for batch in range(100):
test_images, test_labels = test_dataset.next_batch()
feed_dict = {images_placeholder: test_images, labels_placeholder: test_labels}
sess.run(tf.get_collection('update_op'), feed_dict = feed_dict)
sess.run(tf.get_collection('metrics'), feed_dict = feed_dict)
summary_str = sess.run(test_metrics_summary, feed_dict=feed_dict)
summary_writer.add_summary(summary_str, global_step_val)
请注意,要在列车和测试数据上运行并获取相同变量的摘要,我会定义两个用于训练和测试的摘要操作,将它们放在不同的集合中,并分别为列车和测试数据运行两个摘要操作。
我还认为我应该重新初始化update_op的局部变量,因为official doc说
它将sum和count变量设置为零。
我的问题:这是在火车和测试仪上运行操作系统的最佳方法吗?
TF文件也提到了
请注意,在不同输入上多次评估相同指标时,必须指定每个指标的范围,以避免将结果累积在一起:
labels = ...
predictions0 = ...
predictions1 = ...
accuracy0 = tf.contrib.metrics.accuracy(labels, predictions0, name='preds0')
accuracy1 = tf.contrib.metrics.accuracy(labels, predictions1, name='preds1')
但这似乎与我的方法不同,因为我只用一个precision
变量定义了一个变量myprediction
。而且似乎我不需要担心这个问题,也不需要添加name
选项?
答案 0 :(得分:0)
使用tf.metrics
函数时要记住的最重要的事情是,这里的所有满足标准都是运行指标,是在多次运行中计算运行指标。有关详细信息,请参见本文:http://ronny.rest/blog/post_2017_09_11_tf_metrics/
因此,这里您要在训练和测试数据上计算累积精度,因为update_op的精度都相同,并且训练和测试的数据集都在更新相同的运行精度。