是否有一种方法可以评估依赖于tf.data迭代器的张量,但可以暂时暂停该迭代器,以便返回前一批?
想象下面的代码段:
dataset = tf.data.Dataset.range(5)
iterator = dataset.make_one_shot_iterator()
next_batch = iterator.get_next()
train_op = next_batch * 10
每次我评估train_op
时,它都是通过获取新一批数据来实现的,这正是我想要的。但是,每隔N个步骤,我想做一些调试工作,例如评估训练批的准确性,创建检查点,在禁用辍学的情况下运行等。我希望这些操作能在我刚刚使用的同一数据批上进行使用过,但我还没有找到暂停 tf.data
迭代器一个或多个步骤的方法。
显而易见的解决方案是使用占位符,而不是直接使用next_batch
。这意味着我必须首先评估next_batch
,然后使用feed_dict
将其反馈给会话以评估train_op
。我认为不建议这样做,因为会降低性能。还是这样吗?如果是这样,建议如何处理这些案件?
编辑:添加我要使用的伪代码:
for step in num_steps:
sess.run(train_op) # train_op depends on next_batch and therefore fetches new batch
if step % N == 0:
# I want below to run on the same batch above but acc_op also
# depends on next_batch and therefore fetches a new batch
acc = sess.run([acc_op, saver_op, feed_dic={keep_drop:1}])
答案 0 :(得分:-1)
它不能按以下方式工作吗
dataset = tf.data.Dataset.range(5)
iterator = dataset.make_one_shot_iterator()
next_batch = iterator.get_next()
train_op = next_batch * 10
other_ops = do_other_stuff(next_batch)
num_train_batch = 50
for ep in range(num_train_batch):
if ep%N==0:
_, other_stuffs = sess.run([train_op, other_ops])
else:
_ = ses.run(train_op)
,并且您可以每次以不同的方式输入辍学内容