当我编写Tensorflow代码时,我会尽量记住不同事物的类型,例如两个张量的元组或张量列表。这很重要,因为当类型/形状不匹配时,Tensorflow会发出错误。
这个问题标题的文本在文档中显示了很多,尤其是在描述某些函数的结果时,例如Iterator.get_next()
,但我觉得它含糊不清。它并没有告诉我确切的期望,元组列表?元组的元组? “嵌套结构”到底是什么?现在,我可以跟踪的唯一方法是在Session.run()之后打印结果。有没有更清洁,更确定的方法?
此外,似乎Iterator.get_next()
的值始终是一个元素的列表;我无法使其返回非列表,空列表或包含多个元素的列表。 Iterator.get_next()
何时返回不是一个元素的列表的内容?如果从来没有,那么将内容包装在列表中似乎是多余的-为什么Iterator.get_next()
这样设计?
这是示例代码,显示了我的意思:
import numpy as np
import tensorflow as tf
ds = tf.data.Dataset.from_tensor_slices(np.array(range(0, 8)).reshape(4,2))
it = ds.make_one_shot_iterator()
with tf.Session() as sess:
for i in range(0, 4):
x = sess.run([it.get_next()])
print(x)
输出:
[array([0, 1])]
[array([2, 3])]
[array([4, 5])]
[array([6, 7])]
为什么不只是以下内容?
array([0, 1])
array([2, 3])
array([4, 5])
array([6, 7])
答案 0 :(得分:1)
您的特定问题“输出”与“为什么不只是以下内容?”将列表传递到sess.run
的结果。如果改用sess.run(it.get_next())
,则会获得所需的行为。
请注意,TensorFlow允许您传入几种不同的结构(例如,列表,字典,命名元组等),包括嵌套结构(请参见here)。它将按照与您传入数据时相同的结构返回数据。
例如,带有字典:
import numpy as np
import tensorflow as tf
ds = tf.data.Dataset.from_tensor_slices(np.array(range(0, 8)).reshape(4,2))
it = ds.make_one_shot_iterator()
with tf.Session() as sess:
for i in range(0, 4):
x = sess.run({'x': it.get_next()}])
print(x)
输出:
{'x': array([0, 1])}
{'x': array([2, 3])}
{'x': array([4, 5])}
{'x': array([6, 7])}