Manually fetch the next batch or reuse the same batch with the TensorFlow Data API

Posted: 2018-10-11 13:50:43

Tags: python tensorflow tensorflow-datasets

I am trying to use the tf.data API to speed up my code and prevent GPU data starvation, but there is one thing keeping me from adopting it: being able to use the same batch more than once when calling different training ops.

Suppose I set up my dataset as

dataset = tf.data.TextLineDataset("textfile.txt")
dataset = dataset.shuffle(dataset_size)
dataset = dataset.padded_batch(batch_size, ...)
dataset = dataset.repeat()
iterator = dataset.make_one_shot_iterator()
x_batch = iterator.get_next()

loss1 = someFunctionOf(x_batch)
loss2 = someOtherFunctionOf(x_batch)
train_op1 = someOptimizerOf(loss1)
train_op2 = someOtherOptimizerOf(loss2)

But now, every time I call train_op1, iterator.get_next() is evaluated, so by the time I call train_op2 I am already training on the next batch.
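To make the issue concrete, here is a toy sketch (the data and losses are made-up stand-ins): each separate sess.run that depends on x_batch advances the iterator by one batch.

import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices([1.0, 2.0, 3.0, 4.0]).batch(1).repeat()
x_batch = dataset.make_one_shot_iterator().get_next()

loss1 = x_batch * 2.0  # stand-in for someFunctionOf(x_batch)
loss2 = x_batch * 3.0  # stand-in for someOtherFunctionOf(x_batch)

with tf.Session() as sess:
    print(sess.run(loss1))  # computed on batch [1.0]
    print(sess.run(loss2))  # computed on the *next* batch [2.0]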

From this question I know that I can use a combination of flat_map and repeat(n), where n is the number of times I want to repeat the same batch. This n depends on the number of train_ops, which I would have to count manually. Also, I need both train_ops because they optimize different parts of my graph.
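For reference, a minimal sketch of that flat_map + repeat(n) idea on a made-up toy dataset (TF 1.x, not the code from the linked question) could look like this:

import tensorflow as tf

n = 2  # number of train_ops, counted manually
dataset = tf.data.Dataset.from_tensor_slices([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
dataset = dataset.batch(2)
# Turn every batch into a tiny dataset and repeat it n times,
# so n consecutive get_next() calls yield the identical batch.
dataset = dataset.flat_map(
    lambda batch: tf.data.Dataset.from_tensors(batch).repeat(n))

x_batch = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
    for _ in range(6):
        print(sess.run(x_batch))  # each batch of 2 appears twice in a row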

Thanks for your help!

1 answer:

Answer 0 (score: 0)

Try the code below. It creates copies of your inputs and targets, so hopefully they will not change when you switch the optimizer / loss_op. They stay the same between sess.run calls as long as you do not pass the is_new: True flag.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf


def ds_train(batch_size, num_epochs):  
    ds = (tf.data.Dataset.from_tensor_slices(([1.0,2.0,3.0,4.0,5.0], [-1,-2,-3,-4,-5]))
            .batch(batch_size)
            .repeat(num_epochs)        
            )
    return ds


batch_size = 1
input_size = 1
num_epochs = 2

with tf.variable_scope("dataset"):       
    ds_t = ds_train(batch_size, num_epochs)

with tf.variable_scope("iterator"):
    iterator_t = ds_t.make_initializable_iterator()
    iterator_handle = tf.placeholder(tf.string, shape=[], name="iterator_handle")
    iterator = tf.data.Iterator.from_string_handle(iterator_handle, 
                                                iterator_t.output_types,
                                                iterator_t.output_shapes)

    def next_item():
        next_elem = iterator.get_next(name="next_element")
        x, y = tf.cast(next_elem[0], tf.float32), next_elem[1]  # labels are already int32
        return x, y        


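# Non-trainable copies of the current batch; the is_new flag controls whether they are refreshed.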
inputs = tf.Variable(tf.zeros(shape=[batch_size,input_size]), dtype=tf.float32, name="inputs", trainable=False, use_resource=True)
target = tf.Variable(tf.zeros(shape=[batch_size], dtype=tf.int32), dtype=tf.int32, name="target", trainable=False,use_resource=True)
is_new = tf.placeholder_with_default(tf.constant(False), shape=[], name="new_item_flag")

def new_data(batch_size, input_size):
    # run the data layer to generate a new batch
    next_inputs, next_target = next_item()
    next_inputs = tf.reshape(next_inputs, shape=[batch_size, input_size])
    with tf.control_dependencies([tf.assign(inputs, next_inputs), tf.assign(target, next_target)]):
        return tf.identity(inputs), tf.identity(target)

def old_data():
    # just forward the existing batch
    return inputs, target

# Fetch a fresh batch (and cache it in the variables) or reuse the cached one,
# depending on the is_new flag fed at run time.
inputs, target = tf.cond(is_new, lambda: new_data(batch_size, input_size), old_data)

with tf.Session() as sess:
    sess.run([tf.global_variables_initializer(),tf.local_variables_initializer()])
    handle_t = sess.run(iterator_t.string_handle())
    sess.run(iterator_t.initializer)
    while True:
        try:
            print(sess.run([inputs, target], feed_dict={iterator_handle:handle_t, is_new: False}))
            print(sess.run([inputs, target], feed_dict={iterator_handle:handle_t, is_new: False}))
            print(sess.run([inputs, target], feed_dict={iterator_handle:handle_t, is_new: True}))
        except tf.errors.OutOfRangeError:
            print("End of training dataset.")
            break        
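On the first pass through the loop, the two is_new: False calls print the zero-initialized copies, and the is_new: True call then pulls the first element (1.0, -1) into the variables. From then on, each is_new: True call advances the iterator, while the surrounding is_new: False calls keep returning that same cached batch until the dataset (5 elements x 2 epochs) is exhausted.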