关于tf.stack轴的问题()

时间:2018-06-12 15:27:39

标签: tensorflow

我在tensorflow stack 上阅读了tf.stack()的文档。页面上有一个例子:

>>> x = tf.constant([1, 4])
>>> y = tf.constant([2, 5])
>>> z = tf.constant([3, 6])
>>> sess=tf.Session()
>>> sess.run(tf.stack([x, y, z]))
array([[1, 4],
       [2, 5],
       [3, 6]], dtype=int32)
>>> sess.run(tf.stack([x, y, z], axis=1))
array([[1, 2, 3],
       [4, 5, 6]], dtype=int32)

我不明白的是axis=1

的第二个例子

从结果看,它似乎将三个输入行转换为第一列

然后将它们放在axis=1旁边,但是

我认为结果应该是

array([[1,4, 2, 5, 3, 6 ]] dtype=int32 )

任何人都可以帮忙解释一下吗?

谢谢!

1 个答案:

答案 0 :(得分:3)

tf.stack总是添加一个新维度,并始终沿着该新维度连接给定的张量。在您的情况下,您有三个形状为[2]的张量。设置axis=0与添加新的第一维相同,因此每个张量现在都具有形状[1, 2],并且在该维度上连接,因此最终形状将为[3, 2]。也就是说,每个张量都是" row"最后的张量。使用axis=1时,每个单个张量的形状将扩展为[2, 1],结果将具有[2, 3]形状。所以每个给定的张量都是一个"列"由此产生的张量。

换句话说,tf.stack在功能上等同于此:

def tf.stack(tensors, axis=0):
    return tf.concatenate([tf.expand_dims(t, axis=axis) for t in tensors], axis=axis)

但是您期望的结果将通过以下方式获得:

tf.concatenate([tf.expand_dims(t, axis=0) for t in tensors], axis=1)

请注意,在这种情况下,添加的维度和连接维度是不同的。