Iterator.get_next()结果类型“ {tf.Tensor对象的嵌套结构”

时间:2018-10-10 02:10:19

标签: python tensorflow

当我编写Tensorflow代码时,我会尽量记住不同事物的类型,例如两个张量的元组或张量列表。这很重要,因为当类型/形状不匹配时,Tensorflow会发出错误。

这个问题标题的文本在文档中显示了很多,尤其是在描述某些函数的结果时,例如Iterator.get_next(),但我觉得它含糊不清。它并没有告诉我确切的期望,元组列表?元组的元组? “嵌套结构”到底是什么?现在,我可以跟踪的唯一方法是在Session.run()之后打印结果。有没有更清洁,更确定的方法?

此外,似乎Iterator.get_next()的值始终是一个元素的列表;我无法使其返回非列表,空列表或包含多个元素的列表。 Iterator.get_next()何时返回不是一个元素的列表的内容?如果从来没有,那么将内容包装在列表中似乎是多余的-为什么Iterator.get_next()这样设计?

这是示例代码,显示了我的意思:

import numpy as np
import tensorflow as tf

ds = tf.data.Dataset.from_tensor_slices(np.array(range(0, 8)).reshape(4,2))
it = ds.make_one_shot_iterator()

with tf.Session() as sess:
    for i in range(0, 4):
        x = sess.run([it.get_next()])
        print(x)

输出:

[array([0, 1])]
[array([2, 3])]
[array([4, 5])]
[array([6, 7])]

为什么不只是以下内容?

array([0, 1])
array([2, 3])
array([4, 5])
array([6, 7])

1 个答案:

答案 0 :(得分:1)

您的特定问题“输出”与“为什么不只是以下内容?”将列表传递到sess.run的结果。如果改用sess.run(it.get_next()),则会获得所需的行为。

请注意,TensorFlow允许您传入几种不同的结构(例如,列表,字典,命名元组等),包括嵌套结构(请参见here)。它将按照与您传入数据时相同的结构返回数据。

例如,带有字典:

import numpy as np
import tensorflow as tf

ds = tf.data.Dataset.from_tensor_slices(np.array(range(0, 8)).reshape(4,2))
it = ds.make_one_shot_iterator()

with tf.Session() as sess:
    for i in range(0, 4):
        x = sess.run({'x': it.get_next()}])
        print(x)

输出:

{'x': array([0, 1])}
{'x': array([2, 3])}
{'x': array([4, 5])}
{'x': array([6, 7])}