TensorFlow reinitializable iterator problem

Time: 2019-03-08 18:15:41

Tags: tensorflow

import tensorflow as tf  # TensorFlow 1.x API

train_dataset = tf.data.Dataset.from_tensor_slices((train_x, train_y))
train_dataset = train_dataset.batch(64)
test_dataset = tf.data.Dataset.from_tensor_slices((test_x, test_y))
test_dataset = test_dataset.batch(1000)

# One reinitializable iterator shared by both datasets (same types and shapes)
iterator = tf.data.Iterator.from_structure(train_dataset.output_types,
                                           train_dataset.output_shapes)

# Each sess.run() call that evaluates these tensors advances the iterator by one batch
next_element_x, next_element_y = iterator.get_next()

# Running one of these ops points the iterator at the corresponding dataset
training_init_op = iterator.make_initializer(train_dataset)
testing_init_op = iterator.make_initializer(test_dataset)

# DenseNet, nb_block, growth_k and training_flag are defined elsewhere (not shown)
logits = DenseNet(x=next_element_x, nb_blocks=nb_block, filters=growth_k, training=training_flag).model
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=next_element_y, logits=logits))

correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(next_element_y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

I'm trying to use an input pipeline to improve the performance of my code, and I implemented it with a reinitializable iterator.

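In my code, next_element_y, which comes from iterator.get_next(), is used twice in every iteration (by cost and by correct_prediction). So although I have 782 iterations per epoch, I get an out-of-range error after 391 iterations.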

How can I use next_element_y twice without advancing the iterator an extra time within one iteration?
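The training loop is not shown, so the following is only a pure-Python sketch under one assumption: that each step calls sess.run once for cost (or the training op) and a second time for accuracy. In TensorFlow 1.x a single sess.run call executes the get_next() op at most once, however many fetched tensors depend on it, so two separate run calls per step consume two batches. ToySession, OutOfRangeError, cost, accuracy and steps_completed are hypothetical names that model this rule without TensorFlow.

```python
class OutOfRangeError(Exception):
    """Hypothetical stand-in for tf.errors.OutOfRangeError."""


class ToySession:
    """Hypothetical model of TF1 sess.run(): each run() call evaluates the
    get_next() op once, so every fetch in that call sees the same batch."""

    def __init__(self, num_batches):
        self._batches = iter(range(num_batches))  # batch indices 0 .. num_batches-1

    def run(self, fetches):
        try:
            batch = next(self._batches)  # the iterator advances exactly once per run()
        except StopIteration:
            raise OutOfRangeError("End of sequence") from None
        return [fetch(batch) for fetch in fetches]


def cost(batch):
    return ("cost", batch)


def accuracy(batch):
    return ("accuracy", batch)


def steps_completed(num_batches, separate_runs):
    """Count training steps until the iterator is exhausted."""
    sess = ToySession(num_batches)
    steps = 0
    try:
        while True:
            if separate_runs:
                sess.run([cost])      # first run() consumes one batch
                sess.run([accuracy])  # second run() consumes another batch
            else:
                sess.run([cost, accuracy])  # both fetches share a single batch
            steps += 1
    except OutOfRangeError:
        return steps
```

With 782 batches, steps_completed(782, separate_runs=True) returns 391 and steps_completed(782, separate_runs=False) returns 782, which is why fetching cost and accuracy in the same sess.run call keeps them on the same batch.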

2 answers:

Answer 0 (score: 0)

You can just use initializable iterators:

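import tensorflow as tf  # TensorFlow 1.x API

train_dataset = tf.data.Dataset.from_tensor_slices((train_x, train_y))
# shuffle() needs a buffer size (10000 is only an illustrative value);
# repeat() lets you control epochs and avoids the out-of-range error
train_dataset = train_dataset.shuffle(buffer_size=10000).repeat()
train_dataset = train_dataset.batch(64)
test_dataset = tf.data.Dataset.from_tensor_slices((test_x, test_y))
test_dataset = test_dataset.batch(1000)

train_iterator = train_dataset.make_initializable_iterator()
test_iterator = test_dataset.make_initializable_iterator()
# Build the get_next() op once, outside the loop, instead of adding a new op every step
next_element = train_iterator.get_next()
with tf.Session() as sess:
    sess.run(train_iterator.initializer)
    sess.run(test_iterator.initializer)
    for i in range(steps):  # steps: number of training steps, defined elsewhere
        batch = sess.run(next_element)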

The difference is that you fetch next_element once per step and then reuse it. For example:

next_element = train_iterator.get_next()  # build the op once

with tf.Session() as sess:
    sess.run(train_iterator.initializer)
    sess.run(test_iterator.initializer)
    for i in range(steps):
        print(sess.run(next_element))  # prints one batch from the dataset
        print(sess.run(next_element))  # a second run() advances the iterator: prints the next batch

with tf.Session() as sess:
    sess.run(train_iterator.initializer)
    sess.run(test_iterator.initializer)
    for i in range(steps):
        value = sess.run(next_element)  # one run() fetches one batch
        print(value)  # prints that batch
        print(value)  # prints the same batch again
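Because repeat() with no count makes the training dataset endless, the loop bound steps (left undefined above) has to be chosen for the number of epochs you want. The sketch below is only a pure-Python analogue, assuming an in-memory list of examples: repeat_then_batch is a hypothetical helper mimicking dataset.repeat().batch(batch_size), and steps_for_epochs is a hypothetical helper computing how many steps cover the requested epochs. Since repeat() comes before batch(), a batch can straddle two epochs.

```python
import itertools


def repeat_then_batch(examples, batch_size):
    """Hypothetical analogue of dataset.repeat().batch(batch_size):
    an endless stream of fixed-size batches that may straddle epochs."""
    stream = itertools.cycle(examples)  # repeat() with no count never ends
    while True:
        yield list(itertools.islice(stream, batch_size))


def steps_for_epochs(num_examples, batch_size, epochs):
    """Steps needed so that every example is seen at least `epochs` times
    (integer ceiling of num_examples * epochs / batch_size)."""
    return (num_examples * epochs + batch_size - 1) // batch_size
```

For example, taking the first three batches of repeat_then_batch(range(5), 2) gives [[0, 1], [2, 3], [4, 0]], and steps_for_epochs(50000, 64, 1) gives 782.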

Answer 1 (score: 0)

This is how I did it:

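import tensorflow as tf  # TensorFlow 1.x API

# ds_t / ds_v: training and validation datasets built elsewhere
iterator_t = ds_t.make_initializable_iterator()
iterator_v = ds_v.make_initializable_iterator()

# Feedable iterator: the string handle fed at run time picks the dataset
iterator_handle = tf.placeholder(tf.string, shape=[], name="iterator_handle")
iterator = tf.data.Iterator.from_string_handle(iterator_handle,
                                               iterator_t.output_types,
                                               iterator_t.output_shapes)

def get_next_item():
    # sometimes items need casting
    next_elem = iterator.get_next(name="next_element")
    x, y = tf.cast(next_elem[0], tf.float32), next_elem[1]
    return x, y

# Non-trainable buffers holding the last batch; validate_shape=False
# lets them take whatever batch shape arrives
prev_inputs = tf.Variable(tf.zeros([0]), trainable=False, validate_shape=False)
prev_target = tf.Variable(tf.zeros([0], dtype=iterator_t.output_types[1]),
                          trainable=False, validate_shape=False)

def new_data():
    # pull a fresh batch and remember it for later "keep previous" runs
    x, y = get_next_item()
    save = [tf.assign(prev_inputs, x, validate_shape=False),
            tf.assign(prev_target, y, validate_shape=False)]
    with tf.control_dependencies(save):
        return tf.identity(x), tf.identity(y)

def old_data():
    # just forward the previous batch without touching the iterator
    return prev_inputs.read_value(), prev_target.read_value()

is_keep_previous = tf.placeholder_with_default(tf.constant(False), shape=[], name="keep_previous_flag")

# Only the branch selected at run time executes, so the iterator
# advances only when is_keep_previous is False
inputs, target = tf.cond(is_keep_previous, old_data, new_data)

with tf.Session() as sess:
    sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
    handle_t = sess.run(iterator_t.string_handle())
    handle_v = sess.run(iterator_v.string_handle())
    # Run data iterator initialisation
    sess.run(iterator_t.initializer)
    sess.run(iterator_v.initializer)
    while True:
        try:
            inputs_, target_ = sess.run([inputs, target], feed_dict={iterator_handle: handle_t, is_keep_previous: False})
            print(inputs_, target_)
            inputs_, target_ = sess.run([inputs, target], feed_dict={iterator_handle: handle_t, is_keep_previous: True})
            print(inputs_, target_)
            inputs_, target_ = sess.run([inputs, target], feed_dict={iterator_handle: handle_v})
            print(inputs_, target_)
        except tf.errors.OutOfRangeError:
            # one of the iterators has run out of elements
            break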
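The switch used in this answer can be illustrated without TensorFlow. The sketch below is a hypothetical pure-Python analogue (KeepPreviousFeeder and its get method are illustrative names, not a TensorFlow API): the handle selects one of several iterators, like the fed string handle, and keep_previous=True hands back the cached batch without advancing any iterator, like the keep-previous branch of tf.cond.

```python
class KeepPreviousFeeder:
    """Hypothetical analogue of a feedable iterator plus a keep_previous flag."""

    def __init__(self, iterators):
        self._iterators = iterators  # maps a handle name to a Python iterator
        self._last = None            # most recently fetched batch

    def get(self, handle, keep_previous=False):
        if keep_previous:
            # reuse the cached batch; no iterator is advanced
            return self._last
        # pull a fresh batch from the selected iterator and cache it;
        # StopIteration here plays the role of tf.errors.OutOfRangeError
        self._last = next(self._iterators[handle])
        return self._last
```

With feeder = KeepPreviousFeeder({"train": iter([1, 2, 3]), "val": iter(["a", "b"])}), the calls get("train"), get("train", keep_previous=True) and get("train") return 1, 1 and 2: the second call reuses the batch instead of consuming one.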