我想在培训期间使用Dataset
和Iterator
来评估验证集。我想时不时地评估一个(或几个)验证批次 - 时不时地通常不一个时代。
然而,当重新启动切换输入时,可重新初始化的迭代器会重新开始。 E.g。
import tensorflow as tf
dataset_trn = tf.data.Dataset.range(10)
dataset_tst = tf.data.Dataset.range(10).map(lambda i: i + 1000)
iterator = tf.data.Iterator.from_structure(dataset_trn.output_types,
dataset_trn.output_shapes)
batch = iterator.get_next()
trn_init_op = iterator.make_initializer(dataset_trn)
tst_init_op = iterator.make_initializer(dataset_tst)
sess = tf.InteractiveSession()
for _ in range(2):
sess.run(trn_init_op)
for _ in range(5):
print(batch.eval())
sess.run(tst_init_op)
print(batch.eval())
返回
0
1
2
3
4
1000
0
1
2
3
4
1000
但是我想继续这样的训练:
0
1
2
3
4
1000
5
6
7
8
9
1001
有没有办法实现这个目标?请注意,在实践中,批处理是混乱的,我希望它在相同的伪随机点恢复。
答案 0 :(得分:2)
Feedable iterators应该有所帮助,但他们很难与之合作。您需要创建一个占位符和字符串句柄:
dataset_trn = tf.data.Dataset.range(10)
dataset_tst = tf.data.Dataset.range(10).map(lambda i: i + 1000)
holder = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
holder, dataset_trn.output_types, dataset_trn.output_shapes)
batch = iterator.get_next()
trn_iter = dataset_trn.make_one_shot_iterator()
trn_handle = trn_iter.string_handle()
tst_iter = dataset_tst.make_one_shot_iterator()
tst_handle = tst_iter.string_handle()
with tf.Session() as sess:
for _ in range(2):
trn_string = sess.run(trn_handle)
tst_string = sess.run(tst_handle)
for _ in range(5):
print(sess.run(batch, feed_dict={holder: trn_string}))
print(sess.run(batch, feed_dict={holder: tst_string}))