从张量流数据集中获取两个值

时间:2018-05-15 07:09:30

标签: python tensorflow

我在tensorflow中有一个返回元组的数据集,因为它被设计用于创建功能和标签。我如何在python中迭代它并提取这些元组?

import tensorflow as tf
import numpy as np

sess = tf.Session()

def generator():
    index = 0
    while True:
        feature = np.ones([1,4]) * index
        label = feature[0:1,0:2]
        print('yielding:', feature, label)
        yield feature, label
        index +=1


dataset = tf.data.Dataset.from_generator(
    generator=generator,
    output_types=(tf.float64, tf.float64),
    output_shapes=(tf.TensorShape([1,4]),tf.TensorShape([1,2])),
)

iterator = dataset.make_one_shot_iterator()
data = iterator.get_next()

print(data[0].eval(session=sess))
print(data[1].eval(session=sess))
print(data[0].eval(session=sess))
print(data[1].eval(session=sess))

输出:

yielding: [[0. 0. 0. 0.]] [[0. 0.]]
[[0. 0. 0. 0.]]
yielding: [[1. 1. 1. 1.]] [[1. 1.]]
[[1. 1.]]
yielding: [[2. 2. 2. 2.]] [[2. 2.]]
[[2. 2. 2. 2.]]
yielding: [[3. 3. 3. 3.]] [[3. 3.]]
[[3. 3.]]

问题是,每次我评估其中一个数据元素时,迭代器都会执行步骤。但是我无法获得与data[0]data[1]相对应的相同步骤的值。

我正在寻找像eval_tuple(data)这样会返回([[2. 2. 2. 2.]] [[2. 2.]])的内容。

注意:我不是很注重评估我在上面的代码中创建的data元组。目标是从dataset对象中提取匹配的要素 - 标签对。

0 个答案:

没有答案