我正在尝试使用 tf.data API 来加速代码并避免 GPU 数据饥饿,但有一件事阻碍了我采用它:在调用多个训练操作(train op)时,能否多次使用同一个批次的数据?
假设我将数据集设置为
# Sketch of the asker's input pipeline (TF1 tf.data API): shuffled, padded
# batches, repeated forever, consumed through a one-shot iterator.
dataset = tf.data.TextLineDataset("textfile.txt")
dataset = dataset.shuffle(dataset_size)
dataset = dataset.padded_batch(batch_size, ...)
dataset = dataset.repeat()
iterator = dataset.make_one_shot_iterator()
x_batch = iterator.get_next()  # every evaluation of this tensor advances the iterator
# Two losses / optimizers over the SAME batch tensor — the crux of the question.
loss1 = someFunctionOf(x_batch)
loss2 = someOtherFunctionOf(x_batch)
train_op1 = someOptimizerOf(loss1)
train_op2 = someOtherOptimizerOf(loss2)
但现在每当我运行 train_op1 时,都会调用一次 iterator.get_next(),因此在运行 train_op2 时,训练的已经是下一个批次了。
从这个问题中我了解到,可以组合使用 flat_map 和 repeat(n),其中 n 是同一批次要重复的次数。但这个 n 取决于 train_op 的数量,需要我手动去数。另外,我确实需要这两个 train_op,因为它们优化的是计算图的不同部分。
谢谢您的帮助!
答案 0 :(得分:0)
尝试下面的代码。它会为输入和目标创建副本,因此当您切换优化器 / loss_op 时,它们不会改变。只要您不传入 is_new:True 标志,它们就会在多次 sess.run 调用之间保持不变。
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
def ds_train(batch_size, num_epochs):
    """Build a toy dataset of (float, int) pairs, batched then repeated.

    Args:
        batch_size: number of elements per batch.
        num_epochs: how many times to repeat the five-element dataset.

    Returns:
        A tf.data.Dataset yielding batched (value, label) pairs.
    """
    values = [1.0, 2.0, 3.0, 4.0, 5.0]
    labels = [-1, -2, -3, -4, -5]
    dataset = tf.data.Dataset.from_tensor_slices((values, labels))
    dataset = dataset.batch(batch_size)
    return dataset.repeat(num_epochs)
# ---- Toy hyper-parameters ------------------------------------------------
batch_size = 1
input_size = 1
num_epochs = 2

# Build the dataset and an initializable iterator, then layer a
# string-handle ("feedable") iterator on top so the handle can be fed in.
with tf.variable_scope("dataset"):
    ds_t = ds_train(batch_size, num_epochs)
with tf.variable_scope("iterator"):
    iterator_t = ds_t.make_initializable_iterator()
    iterator_handle = tf.placeholder(tf.string, shape=[], name="iterator_handle")
    iterator = tf.data.Iterator.from_string_handle(iterator_handle,
                                                   iterator_t.output_types,
                                                   iterator_t.output_shapes)

def next_item():
    # Pull the next (x, y) pair from whichever iterator the fed handle selects.
    next_elem = iterator.get_next(name="next_element")
    x, y = tf.cast(next_elem[0], tf.float32), next_elem[1]# tf.cast(next_elem[1], tf.int32)
    return x, y

# Non-trainable buffer variables that cache the current batch so the same
# data can be re-read across several sess.run() calls.
inputs = tf.Variable(tf.zeros(shape=[batch_size,input_size]), dtype=tf.float32, name="inputs", trainable=False, use_resource=True)
target = tf.Variable(tf.zeros(shape=[batch_size], dtype=tf.int32), dtype=tf.int32, name="target", trainable=False,use_resource=True)
# Feed is_new=True to advance to a fresh batch; defaults to re-using the cache.
is_new = tf.placeholder_with_default(tf.constant(False), shape=[], name="new_item_flag")

def new_data(batch_size, input_size):
    # run the data layer to generate a new batch
    next_inputs, next_target = next_item()
    next_inputs = tf.reshape(next_inputs, shape=[batch_size, input_size])
    # Control dependency forces the assigns to run before the values are read.
    with tf.control_dependencies([tf.assign(inputs, next_inputs), tf.assign(target, next_target)]):
        return tf.identity(inputs), tf.identity(target)

def old_data():
    # just forward the existing batch
    return inputs, target

# NOTE(review): the results of this next_item() call are never used — it adds
# get_next ops to the graph that are never executed; looks like dead code.
next_inputs, next_target = next_item()
# tf.cond traces both branches here, while `inputs`/`target` still refer to
# the Variables; afterwards the Python names are rebound to the cond outputs.
inputs, target = tf.cond(is_new, lambda:new_data(batch_size, input_size), old_data)

with tf.Session() as sess:
    sess.run([tf.global_variables_initializer(),tf.local_variables_initializer()])
    handle_t = sess.run(iterator_t.string_handle())
    sess.run(iterator_t.initializer)
    while True:
        try:
            # is_new=False re-reads the cached buffers (initially zeros, before
            # the first is_new=True run); is_new=True pulls the next batch.
            print(sess.run([inputs, target], feed_dict={iterator_handle:handle_t, is_new: False}))
            print(sess.run([inputs, target], feed_dict={iterator_handle:handle_t, is_new: False}))
            print(sess.run([inputs, target], feed_dict={iterator_handle:handle_t, is_new: True}))
        except tf.errors.OutOfRangeError:
            print("End of training dataset.")
            break