我们用tf.stack做什么([tf.range(tf.shape(self.a)[0],dtype = tf.int32),self.a],axis = 1)

时间:2018-04-29 04:47:59

标签: python tensorflow

这个命令做了什么? tf.stack代表什么?

tf.stack([tf.range(tf.shape(self.a)[0], dtype=tf.int32), self.a], axis=1)

1 个答案:

答案 0 :(得分:1)

一般来说,在我解决问题时,我会尝试将NumPy等效于TensorFlow功能。最初,TensorFlow API与NumPy API存在一些奇怪的差异,但是足够的用户希望这两个包的行为与TensorFlow进行更改的行为相同。

你说数组self.a保证是1D。那好吧:

import numpy as np
arr = np.random.randint(-9,9,(10,))
print(arr)
result = np.stack([np.arange(np.shape(arr)[0], dtype=np.int32), arr], axis=1)
print(result)

这是一个示例输出:

array([-5,  1,  0, -3, -9, -8,  3, -1,  0, -2])

array([[ 0, -5],
       [ 1,  1],
       [ 2,  0],
       [ 3, -3],
       [ 4, -9],
       [ 5, -8],
       [ 6,  3],
       [ 7, -1],
       [ 8,  0],
       [ 9, -2]])

因此,看起来原始的1D数组被放大为第2列中带有数字索引的2D数组。