我在tensorflow stack 上阅读了tf.stack()
的文档。页面上有一个例子:
>>> x = tf.constant([1, 4])
>>> y = tf.constant([2, 5])
>>> z = tf.constant([3, 6])
>>> sess=tf.Session()
>>> sess.run(tf.stack([x, y, z]))
array([[1, 4],
[2, 5],
[3, 6]], dtype=int32)
>>> sess.run(tf.stack([x, y, z], axis=1))
array([[1, 2, 3],
[4, 5, 6]], dtype=int32)
我不明白的是axis=1
。
从结果看,它似乎将三个输入行转换为第一列
然后将它们放在axis=1
旁边,但是
我认为结果应该是
array([[1,4, 2, 5, 3, 6 ]] dtype=int32 )
任何人都可以帮忙解释一下吗?
谢谢!
答案 0 :(得分:3)
tf.stack
总是添加一个新维度,并始终沿着该新维度连接给定的张量。在您的情况下,您有三个形状为[2]
的张量。设置axis=0
与添加新的第一维相同,因此每个张量现在都具有形状[1, 2]
,并且在该维度上连接,因此最终形状将为[3, 2]
。也就是说,每个张量都是" row"最后的张量。使用axis=1
时,每个单个张量的形状将扩展为[2, 1]
,结果将具有[2, 3]
形状。所以每个给定的张量都是一个"列"由此产生的张量。
换句话说,tf.stack
在功能上等同于此:
def tf.stack(tensors, axis=0):
return tf.concatenate([tf.expand_dims(t, axis=axis) for t in tensors], axis=axis)
但是您期望的结果将通过以下方式获得:
tf.concatenate([tf.expand_dims(t, axis=0) for t in tensors], axis=1)
请注意,在这种情况下,添加的维度和连接维度是不同的。