我在tensorflow中有一个返回元组的数据集,因为它被设计用于创建功能和标签。我如何在python中迭代它并提取这些元组?
import tensorflow as tf
import numpy as np
sess = tf.Session()
def generator():
index = 0
while True:
feature = np.ones([1,4]) * index
label = feature[0:1,0:2]
print('yielding:', feature, label)
yield feature, label
index +=1
dataset = tf.data.Dataset.from_generator(
generator=generator,
output_types=(tf.float64, tf.float64),
output_shapes=(tf.TensorShape([1,4]),tf.TensorShape([1,2])),
)
iterator = dataset.make_one_shot_iterator()
data = iterator.get_next()
print(data[0].eval(session=sess))
print(data[1].eval(session=sess))
print(data[0].eval(session=sess))
print(data[1].eval(session=sess))
输出:
yielding: [[0. 0. 0. 0.]] [[0. 0.]]
[[0. 0. 0. 0.]]
yielding: [[1. 1. 1. 1.]] [[1. 1.]]
[[1. 1.]]
yielding: [[2. 2. 2. 2.]] [[2. 2.]]
[[2. 2. 2. 2.]]
yielding: [[3. 3. 3. 3.]] [[3. 3.]]
[[3. 3.]]
问题是,每次我评估其中一个数据元素时,迭代器都会执行步骤。但是我无法获得与data[0]
和data[1]
相对应的相同步骤的值。
我正在寻找像eval_tuple(data)
这样会返回([[2. 2. 2. 2.]] [[2. 2.]])
的内容。
注意:我不是很注重评估我在上面的代码中创建的data
元组。目标是从dataset
对象中提取匹配的要素 - 标签对。