可重新初始化的迭代器,用于同步培训和验证

时间:2017-11-22 10:05:53

标签: python tensorflow

我想在培训期间使用DatasetIterator来评估验证集。我想时不时地评估一个(或几个)验证批次 - 时不时地通常一个时代。

然而,当重新启动切换输入时,可重新初始化的迭代器会重新开始。 E.g。

import tensorflow as tf

dataset_trn = tf.data.Dataset.range(10)
dataset_tst = tf.data.Dataset.range(10).map(lambda i: i + 1000)
iterator = tf.data.Iterator.from_structure(dataset_trn.output_types,
    dataset_trn.output_shapes)
batch = iterator.get_next()
trn_init_op = iterator.make_initializer(dataset_trn)
tst_init_op = iterator.make_initializer(dataset_tst)

sess = tf.InteractiveSession()

for _ in range(2):
  sess.run(trn_init_op)
  for _ in range(5):
    print(batch.eval())

  sess.run(tst_init_op)
  print(batch.eval())

返回

0
1
2
3
4
1000
0
1
2
3
4
1000

但是我想继续这样的训练:

0
1
2
3
4
1000
5
6
7
8
9
1001

有没有办法实现这个目标?请注意,在实践中,批处理是混乱的,我希望它在相同的伪随机点恢复。

1 个答案:

答案 0 :(得分:2)

Feedable iterators应该有所帮助,但他们很难与之合作。您需要创建一个占位符和字符串句柄:

dataset_trn = tf.data.Dataset.range(10)
dataset_tst = tf.data.Dataset.range(10).map(lambda i: i + 1000)

holder = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    holder, dataset_trn.output_types, dataset_trn.output_shapes)
batch = iterator.get_next()

trn_iter = dataset_trn.make_one_shot_iterator()
trn_handle = trn_iter.string_handle()

tst_iter = dataset_tst.make_one_shot_iterator()
tst_handle = tst_iter.string_handle()

with tf.Session() as sess:
  for _ in range(2):

    trn_string = sess.run(trn_handle)
    tst_string = sess.run(tst_handle)

    for _ in range(5):
      print(sess.run(batch, feed_dict={holder: trn_string}))

    print(sess.run(batch, feed_dict={holder: tst_string}))