# --- TF 1.x input pipeline (question code) ---
# Build batched train/test datasets from in-memory tensors.
train_dataset = tf.data.Dataset.from_tensor_slices((train_x, train_y))
train_dataset = train_dataset.batch(64)
test_dataset = tf.data.Dataset.from_tensor_slices((test_x, test_y))
test_dataset = test_dataset.batch(1000)
# Reinitializable iterator: one iterator shared by both datasets; switching
# between them is done by running the matching init op created below.
iterator = tf.data.Iterator.from_structure(train_dataset.output_types,
train_dataset.output_shapes)
# get_next() is called ONCE here, so next_element_x / next_element_y are
# graph tensors; a single sess.run advances the iterator exactly once no
# matter how many fetched ops depend on them.  (Running cost and accuracy
# in SEPARATE sess.run calls would advance it twice per loop iteration —
# presumably the cause of the mid-epoch OutOfRangeError described below.)
next_element_x, next_element_y = iterator.get_next()
training_init_op = iterator.make_initializer(train_dataset)
testing_init_op = iterator.make_initializer(test_dataset)
# NOTE(review): DenseNet, nb_block, growth_k, training_flag are defined
# elsewhere in the asker's script — not visible here.
logits = DenseNet(x=next_element_x, nb_blocks=nb_block, filters=growth_k, training=training_flag).model
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=next_element_y, logits=logits))
# Accuracy: fraction of batch rows where predicted class matches the
# one-hot label's argmax.
correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(next_element_y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
我正在尝试使用输入管道来提高代码的性能,并用可重新初始化的迭代器(reinitializable iterator)实现它。
在我的代码中,由 iterator.get_next() 生成的 next_element_y 在每次迭代中要被使用两次(分别用于 cost 和 correct_prediction)。
因此,虽然每个 epoch 应有 782 次迭代,但在 391 次迭代后就出现了 OutOfRange(超出范围)错误。
如何在一次迭代中两次使用 next_element_y,而不额外触发迭代器?
答案 0(得分:0):
您可以只使用initializable_iterator
# Answer 0: give each dataset its own initializable iterator instead of one
# shared reinitializable iterator; .repeat() lets the training loop control
# the number of epochs, avoiding the mid-epoch OutOfRangeError.
train_dataset = tf.data.Dataset.from_tensor_slices((train_x, train_y))
# Fixed: the original line referenced the undefined name `ds_train`, and
# Dataset.shuffle() requires a buffer_size argument in TF 1.x.
train_dataset = train_dataset.shuffle(buffer_size=10000).repeat()  # repeat to control epochs and avoid out-of-range error
train_dataset = train_dataset.batch(64)
test_dataset = tf.data.Dataset.from_tensor_slices((test_x, test_y))
test_dataset = test_dataset.batch(1000)
train_iterator = train_dataset.make_initializable_iterator()
test_iterator = test_dataset.make_initializable_iterator()

# Fixed: get_next() must be called ONCE, outside the loop.  Calling it inside
# the loop (as the original did) adds a new op to the graph on every
# iteration and leaks memory.
next_batch = train_iterator.get_next()

with tf.Session() as sess:
    sess.run(train_iterator.initializer)
    sess.run(test_iterator.initializer)
    # NOTE(review): if `steps` is an int this should be range(steps) —
    # kept as written, assuming `steps` is an iterable; confirm.
    for i in steps:
        next_element = sess.run(next_batch)
区别在于:您只定义一次 next_element,然后复用它。例如:
# BAD pattern: each sess.run(train_iterator.get_next()) both adds a new op to
# the graph AND advances the iterator, so the two prints show DIFFERENT
# elements (and the iterator is consumed twice as fast).
with tf.Session() as sess:
    sess.run(train_iterator.initializer)
    sess.run(test_iterator.initializer)
    for i in steps:
        print(sess.run(train_iterator.get_next()))  # This will print 1st element in dataset
        print(sess.run(train_iterator.get_next()))  # This will print next element in dataset

# GOOD pattern: fetch the element once per iteration and reuse the Python
# value; the iterator advances only on the single sess.run.
with tf.Session() as sess:
    sess.run(train_iterator.initializer)
    sess.run(test_iterator.initializer)
    for i in steps:
        # Fixed: the original line had an unbalanced extra ')' (SyntaxError).
        next_element = sess.run(train_iterator.get_next())
        print(next_element)  # This will print 1st element in dataset
        print(next_element)  # And this will print 1st element
答案 1(得分:0):
我是这样做的:
# Answer 1: a feedable (string-handle) iterator plus tf.cond, so one graph
# can read from either the train or validation iterator, and the same batch
# can optionally be reused on consecutive sess.run calls.
iterator_t = ds_t.make_initializable_iterator()
iterator_v = ds_v.make_initializable_iterator()
iterator_handle = tf.placeholder(tf.string, shape=[], name="iterator_handle")
iterator = tf.data.Iterator.from_string_handle(iterator_handle,
                                               iterator_t.output_types,
                                               iterator_t.output_shapes)

def new_data():
    # Fixed: originally defined as `get_next_item`, but tf.cond below
    # referenced the undefined name `new_data` (NameError).
    # sometimes items need casting
    next_elem = iterator.get_next(name="next_element")
    x, y = tf.cast(next_elem[0], tf.float32), next_elem[1]
    return x, y

def old_data():
    # just forward the existing batch
    # NOTE(review): `inputs`/`target` are only assigned by the tf.cond below,
    # and tf.cond traces both branch functions when the op is built — this
    # forward reference likely needs a tf.Variable cache of the previous
    # batch to actually work; confirm against the answer's full code.
    return inputs, target

is_keep_previous = tf.placeholder_with_default(tf.constant(False), shape=[], name="keep_previous_flag")
# When is_keep_previous is fed True, re-emit the previous batch; otherwise
# pull a fresh one from whichever iterator the fed handle selects.
inputs, target = tf.cond(is_keep_previous, old_data, new_data)

with tf.Session() as sess:
    sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
    # string_handle() values are fed at run time to select an iterator.
    handle_t = sess.run(iterator_t.string_handle())
    handle_v = sess.run(iterator_v.string_handle())
    # Run data iterator initialisation
    sess.run(iterator_t.initializer)
    sess.run(iterator_v.initializer)
    while True:
        try:
            # Fresh training batch.
            inputs_, target_ = sess.run([inputs, target], feed_dict={iterator_handle: handle_t, is_keep_previous: False})
            print(inputs_, target_)
            # Same batch again (keep_previous=True).
            inputs_, target_ = sess.run([inputs, target], feed_dict={iterator_handle: handle_t, is_keep_previous: True})
            print(inputs_, target_)
            # Batch from the validation iterator.
            inputs_, target_ = sess.run([inputs, target], feed_dict={iterator_handle: handle_v})
            print(inputs_, target_)
        except tf.errors.OutOfRangeError:
            # now we know we ran out of elements in the validation iterator
            break