使用TensorFlow Dataset API的Epoch计数器

时间:2017-11-21 10:27:47

标签: python tensorflow

我正在将TensorFlow代码从旧队列界面更改为新Dataset API。在我的旧代码中,每次在队列中访问和处理新的输入张量时,我都会通过递增tf.Variable来跟踪纪元数。我想用新的数据集API来计算这个时代,但是我在使用它时遇到了一些麻烦。

由于我在预处理阶段生成了可变数量的数据项,因此在训练循环中递增(Python)计数器并不是一件简单的事情 - 我需要计算相对于输入队列或数据集。

我使用旧的队列系统模仿我以前所拥有的东西,这就是我最终得到的数据集API(简化示例):

with tf.Graph().as_default():

    data = tf.ones(shape=(10, 512), dtype=tf.float32, name="data")
    input_tensors = (data,)

    epoch_counter = tf.Variable(initial_value=0.0, dtype=tf.float32,
                                trainable=False)

    def pre_processing_func(data_):
        data_size = tf.constant(0.1, dtype=tf.float32)
        epoch_counter_op = tf.assign_add(epoch_counter, data_size)
        with tf.control_dependencies([epoch_counter_op]):
            # normally I would do data-augmentation here
            results = (tf.expand_dims(data_, axis=0),)
            return tf.data.Dataset.from_tensor_slices(results)

    dataset_source = tf.data.Dataset.from_tensor_slices(input_tensors)
    dataset = dataset_source.flat_map(pre_processing_func)
    dataset = dataset.repeat()
    # ... do something with 'dataset' and print
    # the value of 'epoch_counter' every once a while

然而,这不起作用。它崩溃了一个神秘的错误信息:

 TypeError: In op 'AssignAdd', input types ([tf.float32, tf.float32])
 are not compatible with expected types ([tf.float32_ref, tf.float32])

仔细检查表明epoch_counter可能根本无法访问pre_processing_func变量。它可能生活在不同的图表中吗?

知道如何修复上面的例子吗?或者如何通过其他方式获得纪元计数器(带小数点,例如0.4或2.9)?

3 个答案:

答案 0 :(得分:7)

TL; DR :将epoch_counter的定义替换为以下内容:

epoch_counter = tf.get_variable("epoch_counter", initializer=0.0,
                                trainable=False, use_resource=True)

tf.data.Dataset转换中使用TensorFlow变量存在一些限制。原则限制是所有变量必须是“资源变量”而不是旧的“参考变量”;遗憾的是tf.Variable仍然会出于向后兼容性原因创建“参考变量”。

一般来说,如果可以避免使用变量,我建议不要在tf.data管道中使用变量。例如,您可以使用Dataset.range()来定义纪元计数器,然后执行以下操作:

epoch_counter = tf.data.Dataset.range(NUM_EPOCHS)
dataset = epoch_counter.flat_map(lambda i: tf.data.Dataset.zip(
    (pre_processing_func(data), tf.data.Dataset.from_tensors(i).repeat()))

上面的代码片段将每个值附加一个纪元计数器作为第二个组件。

答案 1 :(得分:0)

data = data.repeat(num_epochs)行导致重复已经为num_epochs个重复的数据集(也是纪元计数器)。将for _ in range(num_iters):替换为for _ in range(num_iters+1):即可轻松获得。

答案 2 :(得分:0)

我将numerica的示例代码扩展为批次,并替换了itertool部分:

num_examples = 5
num_epochs = 4
batch_size = 2
num_iters = int(num_examples * num_epochs / batch_size)

features = tf.data.Dataset.range(num_examples)
labels = tf.data.Dataset.range(num_examples)

data = tf.data.Dataset.zip((features, labels))
data = data.shuffle(num_examples)

epoch = tf.data.Dataset.range(num_epochs)
data = epoch.flat_map(
    lambda i: tf.data.Dataset.zip((
        data,
        tf.data.Dataset.from_tensors(i).repeat(),
        tf.data.Dataset.range(num_examples)
    ))
)

# to flatten the nested datasets
data = data.map(lambda samples, *cnts: samples+cnts )
data = data.batch(batch_size=batch_size)

it = data.make_one_shot_iterator()
x, y, ep, st = it.get_next()

with tf.Session() as sess:
    for _ in range(num_iters):
        x_, y_, ep_, st_ = sess.run([x, y, ep, st])
        print(f'step {st_}\t epoch {ep_} \t x {x_} \t y {y_}')