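When using tf.data.Dataset.batch, a call to get_next() preloads some data beyond the batch that was requested. There seems to be a background thread doing this. Is there a way to disable it?
Code snippet to reproduce: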
import tensorflow as tf

def pr(x):
    # Print each element as it is produced so the evaluation order is visible.
    print(x)
    return x

dataset = tf.data.Dataset.range(10000)
dataset = dataset.map(lambda x: tf.py_func(pr, [x], [tf.int64]))
dataset = dataset.batch(3)
iterator = dataset.make_initializable_iterator()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(iterator.initializer)
    next_element = iterator.get_next()
    for i in range(2):
        fetches = sess.run(next_element)
        print(fetches)
Sample output (it varies from run to run):
0
1
2
3
(array([0, 1, 2]),)
4
5
6
(array([3, 4, 5]),)
7
8
I expect this deterministic output instead:
0
1
2
(array([0, 1, 2]),)
3
4
5
(array([3, 4, 5]),)
Environment: OS X, Python 3.7.2, TensorFlow 1.13.1, running in CPU mode.
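To make the difference between the two outputs concrete, here is a plain-Python analogy (standard library only). It is not how tf.data is implemented; the helper names `traced`, `lazy_batches`, and `prefetching_batches` are hypothetical and exist only for this sketch. The pull-based version produces elements only when a batch is requested, which matches the expected output. The version with a background thread returns the same batches but may already have produced later elements, which matches the observed output.

```python
import itertools
import queue
import threading


def traced(source, log):
    """Yield items from source, recording each one when it is produced."""
    for x in source:
        log.append(("element", x))
        yield x


def lazy_batches(source, batch_size):
    """Pull-based batching: elements are produced only on request."""
    it = iter(source)
    while True:
        batch = list(itertools.islice(it, batch_size))
        if not batch:
            return
        yield batch


def prefetching_batches(source, batch_size, buffer_size=1):
    """Same batches, but a background thread builds them ahead of time."""
    buffer = queue.Queue(maxsize=buffer_size)
    done = object()  # sentinel marking the end of the source

    def worker():
        for batch in lazy_batches(source, batch_size):
            buffer.put(batch)  # blocks while the buffer is full
        buffer.put(done)

    threading.Thread(target=worker, daemon=True).start()
    while True:
        item = buffer.get()
        if item is done:
            return
        yield item


# Pull-based: the log interleaves exactly as in the expected output.
lazy_log = []
lazy_iter = lazy_batches(traced(range(10000), lazy_log), 3)
for _ in range(2):
    lazy_log.append(("batch", next(lazy_iter)))

# Prefetching: the batch contents are identical, but elements beyond the
# requested batches may already have been produced by the worker thread.
eager_log = []
eager_iter = prefetching_batches(traced(range(10000), eager_log), 3)
eager_batches = [next(eager_iter) for _ in range(2)]
```

In this sketch `lazy_log` always has the same order, while the length of `eager_log` depends on thread scheduling, so how far the producer runs ahead differs between runs.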
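Answer 0 (score: 1)
OK, thanks to giser_yugang's comment. I found a hint in the 1.13 changelog (https://github.com/tensorflow/tensorflow/releases/tag/v1.13.1).
Disabling the default optimizations through the dataset options resolves the issue in version 1.13: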
import tensorflow as tf

def pr(x):
    # Print each element as it is produced so the evaluation order is visible.
    print(x)
    return x

dataset = tf.data.Dataset.range(10000)

# Turn off the optimizations that tf.data applies by default.
options = tf.data.Options()
options.experimental_optimization.apply_default_optimizations = False
dataset = dataset.with_options(options)

dataset = dataset.map(lambda x: tf.py_func(pr, [x], [tf.int64]))
dataset = dataset.batch(3)
iterator = dataset.make_initializable_iterator()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(iterator.initializer)
    next_element = iterator.get_next()
    for i in range(2):
        fetches = sess.run(next_element)
        print(fetches)