张量流中使用tf.gather_nd的3d到2d张量

时间:2016-12-12 09:03:39

标签: indexing tensorflow

我想从3d张量中提取元素以产生2d张量。 我有3d张量(2000 X 2 X 3)和3-dim(2000)的指数1d张量。 指数张量肯定包含2000个元素的指数0~2。

事实上,我希望A[:,:,inds]获得相同的结果。 如何使用tf.gather_nd,请帮助我。

1 个答案:

答案 0 :(得分:0)

如果张量的大小2000 x 2 x 3,您可以使用tf.gather

a = tf.random_normal([2000, 2, 3])
b = tf.gather(a, [1007, 8, 7, 9])
b.get_shape()
TensorShape([Dimension(4), Dimension(2), Dimension(3)])

如果形状为2 x 3 x 2000,您可以:

  • 使用2000 x 2 x 3将形状更改为tf.reshape,然后执行相同的操作,如上所述
  • 使用常规python机制进行索引,然后pack张量回来:

    a = tf.random_normal([2, 3, 2000])
    indices = [1007, 8, 7, 9]
    subtensor = [a[:, :, i] for i in indices]
    b = tf.pack(subtensor, axis=2)
    b.get_shape()
    <tf.Tensor 'pack:0' shape=(2, 3, 4) dtype=float32>