这个命令做了什么? tf.stack
代表什么?
tf.stack([tf.range(tf.shape(self.a)[0], dtype=tf.int32), self.a], axis=1)
答案 0 :(得分:1)
一般来说,在我解决问题时,我会尝试将NumPy等效于TensorFlow功能。最初,TensorFlow API与NumPy API存在一些奇怪的差异,但是足够的用户希望这两个包的行为与TensorFlow进行更改的行为相同。
你说数组self.a
保证是1D。那好吧:
import numpy as np
arr = np.random.randint(-9,9,(10,))
print(arr)
result = np.stack([np.arange(np.shape(arr)[0], dtype=np.int32), arr], axis=1)
print(result)
这是一个示例输出:
array([-5, 1, 0, -3, -9, -8, 3, -1, 0, -2])
array([[ 0, -5],
[ 1, 1],
[ 2, 0],
[ 3, -3],
[ 4, -9],
[ 5, -8],
[ 6, 3],
[ 7, -1],
[ 8, 0],
[ 9, -2]])
因此,看起来原始的1D数组被放大为第2列中带有数字索引的2D数组。