在2018年TensorFlow开发者峰会的tf.data talk中,Derek Murray提出了一种方法,将tf.data
API与TensorFlow的急切执行模式(10:54)结合起来。我试用了那里显示的代码的简化版本:
import tensorflow as tf
tf.enable_eager_execution()
dataset = tf.data.Dataset.from_tensor_slices(tf.random_uniform([50, 10]))
dataset = dataset.batch(5)
for batch in dataset:
print(batch)
造成
TypeError: 'BatchDataset' object is not iterable
我还尝试使用dataset.make_one_shot_iterator()
和dataset.make_initializable_iterator()
迭代数据集,但结果却是
RuntimeError: dataset.make_one_shot_iterator is not supported when eager execution is enabled.
和
RuntimeError: dataset.make_initializable_iterator is not supported when eager execution is enabled.
TensorFlow版本:1.7.0,Python版本:3.6
如何将tf.data
API与急切执行一起使用?
答案 0 :(得分:7)
make_one_shot_iterator()
应该在TensorFlow 1.8中工作,但是现在(例如,对于TensorFlow 1.7),请执行以下操作:
import tensorflow.contrib.eager as tfe
dataset = tf.data.Dataset.from_tensor_slices(tf.random_uniform([50, 10]))
dataset = dataset.batch(5)
for batch in tfe.Iterator(dataset):
print(batch)
答案 1 :(得分:3)
使用 TF 2.1 ,
您可以这样创建一个迭代器:
iterator = iter(dataset)
并获取下一批值:
batch = iterator.get_next()