我的代码多少次从迭代器获取数据?

时间:2018-08-16 14:57:50

标签: python python-3.x tensorflow iterator dataset

我使用TFRecord管理数据集。

dataset = tf.data.TFRecordDataset(files)
dataset = dataset.map(...)
dataset = dataset.shuffle(...)
dataset = dataset.batch(...)
dataset = dataset.repeat(...)
iterator = dataset.make_initializable_iterator()
image_batch, label_batch = iterator.get_next()

网络的输出:

logits_batch = network(image_batch)

我使用 tf.metrics 向我展示效果。

acc_value_op, acc_update_op = tf.metrics.accuracy(labels=label_batch, predictions=predict_batch, name="accuracy")

在tf.Session()中,我有以下代码:

_, loss_value, g_step, _, summary = sess.run(
    [train_op, loss_op, g_step_op, acc_update_op, summary_op],
    feed_dict={handle: train_iterator_handle})
acc_value = sess.run(
    [acc_value_op],
    feed_dict={handle: train_iterator_handle})

我将acc_update_op放在acc_value_op之前是因为我想先更新metrics.accuracy ,然后获取metrics.accuracy 的结果。

但是让我感到困惑的是

1)这两个sess.run(...)是否实际上会获取两批数据还是同一批数据?

2),我可以获取一批的最新acc值吗?

acc_value, _ = sess.run([acc_value_op, acc_update_op], feed_dict={.....})

1 个答案:

答案 0 :(得分:2)

数据集迭代器在两次运行之间保持一种状态,因此,每次调用run时,迭代器都会返回一个新的不同批次。如果希望它再次返回第一批,则必须初始化迭代器。

该行:

acc_value, _ = sess.run([acc_value_op, acc_update_op], feed_dict={.....})

将为您提供最新的累积精度值,它实际上等效于:

acc_value = sess.run(acc_update_op, feed_dict={.....})

由于acc_update_op的返回值与acc_value_op的返回值相同(请参见tf.metrics.accuracy)。两者之间的唯一区别是,运行第二个将更新内部指标变量,以便下次评估它时,它将反映累计值。请注意,您可以像这样操作op,将累计指标重置为零:

reset_metrics_op = tf.variables_initializer(tf.get_collection(METRIC_VARIABLES))

如果要同时具有批次精度值和累积精度值,则可以使用第二个指标:

batch_acc_value_op, _ = tf.metrics.accuracy(
    labels=label_batch, predictions=predict_batch, name="batch_accuracy")