我想从3d张量中提取元素以产生2d张量。 我有3d张量(2000 X 2 X 3)和3-dim(2000)的指数1d张量。 指数张量肯定包含2000个元素的指数0~2。
事实上,我希望A[:,:,inds]
获得相同的结果。
如何使用tf.gather_nd
,请帮助我。
答案 0 :(得分:0)
如果张量的大小2000 x 2 x 3
,您可以使用tf.gather
:
a = tf.random_normal([2000, 2, 3])
b = tf.gather(a, [1007, 8, 7, 9])
b.get_shape()
TensorShape([Dimension(4), Dimension(2), Dimension(3)])
如果形状为2 x 3 x 2000
,您可以:
2000 x 2 x 3
将形状更改为tf.reshape
,然后执行相同的操作,如上所述使用常规python机制进行索引,然后pack张量回来:
a = tf.random_normal([2, 3, 2000])
indices = [1007, 8, 7, 9]
subtensor = [a[:, :, i] for i in indices]
b = tf.pack(subtensor, axis=2)
b.get_shape()
<tf.Tensor 'pack:0' shape=(2, 3, 4) dtype=float32>