Does tf.data.Dataset.batch preload data by default, and how can I disable it?

Posted: 2019-04-06 04:23:09

Tags: python tensorflow tensorflow-datasets

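When I use tf.data.Dataset.batch, calling get_next() preloads some data beyond the batch being requested. It looks as if a background thread is doing this. Is there a way to disable it?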
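To illustrate the effect described above, here is a minimal pure-Python model. It is an assumption-laden sketch, not TensorFlow's actual implementation, and the helpers logged_map, batch, prefetch and take are hypothetical. A purely lazy map-then-batch pipeline produces exactly the interleaving the question asks for, while adding a background prefetch thread leaves the batch values unchanged but lets the producer run ahead of the consumer, so the order of the side effects varies between runs.

```python
import queue
import threading

# Hypothetical helpers modelling a map -> batch -> (optional) prefetch pipeline.
# This is an illustration only, not how tf.data is implemented.


def logged_map(n, events):
    """Lazily yield 0..n-1, recording each element when it is produced (like pr())."""
    for x in range(n):
        events.append(x)
        yield x


def batch(it, size):
    """Group items into tuples of `size`; a shorter final tuple is kept."""
    buf = []
    for x in it:
        buf.append(x)
        if len(buf) == size:
            yield tuple(buf)
            buf = []
    if buf:
        yield tuple(buf)


def prefetch(it, buffer_size):
    """Consume `it` on a background thread, keeping up to `buffer_size` items ready."""
    q = queue.Queue(maxsize=buffer_size)
    done = object()  # sentinel marking the end of the input

    def producer():
        for item in it:
            q.put(item)  # blocks while the buffer is full
        q.put(done)

    threading.Thread(target=producer, daemon=True).start()
    while True:
        item = q.get()
        if item is done:
            return
        yield item


def take(pipeline, k, events):
    """Pull k batches from the pipeline, recording each batch as it is received."""
    for _ in range(k):
        events.append(next(pipeline))


# Lazy pipeline: every element is produced only when the consumer asks for it.
lazy_events = []
take(batch(logged_map(12, lazy_events), 3), 2, lazy_events)
print(lazy_events)  # [0, 1, 2, (0, 1, 2), 3, 4, 5, (3, 4, 5)]

# Prefetching pipeline: same batches, but the producer thread may run ahead,
# so where the integers land relative to the batches can differ between runs.
pf_events = []
take(prefetch(batch(logged_map(12, pf_events), 3), buffer_size=1), 2, pf_events)
print(pf_events)
```

In both cases the batch contents are identical; only the timing of the side effects (the logged elements, standing in for the prints) changes.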

Code to reproduce:

import tensorflow as tf

def pr(x):
    print(x)
    return x


dataset = tf.data.Dataset.range(10000)
dataset = dataset.map(lambda x: tf.py_func(pr, [x], [tf.int64]))

dataset = dataset.batch(3)

iterator = dataset.make_initializable_iterator()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(iterator.initializer)
    next_element = iterator.get_next()

    for i in range(2):
        fetches = sess.run(next_element)
        print(fetches)

Sample output (it varies from run to run):

0
1
2
3
(array([0, 1, 2]),)
4
5
6
(array([3, 4, 5]),)
7
8

I would like deterministic output like this:

0
1
2
(array([0, 1, 2]),)
3
4
5
(array([3, 4, 5]),)

Environment: OS X, Python 3.7.2, TensorFlow 1.13.1 (CPU only).

1 answer:

Answer 0 (score: 1)

OK, thanks to giser_yugang's comment. I found a hint in the 1.13 changelog (https://github.com/tensorflow/tensorflow/releases/tag/v1.13.1).

Setting dataset options solves this problem in version 1.13:


import tensorflow as tf

def pr(x):
    print(x)
    return x


dataset = tf.data.Dataset.range(10000)

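# Turn off tf.data's default graph optimizations for this dataset.
# Per this answer, this stops elements of the next batch from being
# computed before the current batch is returned (reported for TF 1.13).
options = tf.data.Options()
options.experimental_optimization.apply_default_optimizations = False
dataset = dataset.with_options(options)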

dataset = dataset.map(lambda x: tf.py_func(pr, [x], [tf.int64]))

dataset = dataset.batch(3)

iterator = dataset.make_initializable_iterator()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(iterator.initializer)
    next_element = iterator.get_next()

    for i in range(2):
        fetches = sess.run(next_element)
        print(fetches)