如何使用TensorFlow的`tf.data` API多次获取相同的数据批处理

时间:2018-08-23 17:15:50

标签: tensorflow tensorflow-datasets

是否有一种方法可以评估依赖于tf.data迭代器的张量,但可以暂时暂停该迭代器,以便返回前一批?

想象下面的代码段:

dataset = tf.data.Dataset.range(5)
iterator = dataset.make_one_shot_iterator()
next_batch = iterator.get_next()
train_op = next_batch * 10

每次我评估train_op时,它都是通过获取新一批数据来实现的,这正是我想要的。但是,每隔N个步骤,我想做一些调试工作,例如评估训练批的准确性,创建检查点,在禁用辍学的情况下运行等。我希望这些操作能在我刚刚使用的同一数据批上进行使用过,但我还没有找到暂停 tf.data迭代器一个或多个步骤的方法。

显而易见的解决方案是使用占位符,而不是直接使用next_batch。这意味着我必须首先评估next_batch,然后使用feed_dict将其反馈给会话以评估train_op。我认为不建议这样做,因为会降低性能。还是这样吗?如果是这样,建议如何处理这些案件?

编辑:添加我要使用的伪代码:

for step in num_steps:

    sess.run(train_op) # train_op depends on next_batch and therefore fetches new batch

    if step % N == 0:
        # I want below to run on the same batch above but acc_op also
        # depends on next_batch and therefore fetches a new batch
        acc = sess.run([acc_op, saver_op, feed_dic={keep_drop:1}]) 

1 个答案:

答案 0 :(得分:-1)

它不能按以下方式工作吗

dataset = tf.data.Dataset.range(5)
iterator = dataset.make_one_shot_iterator()
next_batch = iterator.get_next()
train_op = next_batch * 10
other_ops = do_other_stuff(next_batch)

num_train_batch = 50
for ep in range(num_train_batch):
   if ep%N==0:
      _, other_stuffs = sess.run([train_op, other_ops])
   else:
      _ = ses.run(train_op)

,并且您可以每次以不同的方式输入辍学内容