如何在TensorFlow中沿维度选择张量的元素(而非切片)

时间:2019-04-16 03:03:49

标签: python tensorflow

我有一个3x2x4张量:

x = tf.reshape(tf.range(24), (3,2,4))
<tf.Tensor: id=1928, shape=(3, 2, 4), dtype=int64, numpy=
array([[[ 0,  1,  2,  3],
        [ 4,  5,  6,  7]],

       [[ 8,  9, 10, 11],
        [12, 13, 14, 15]],

       [[16, 17, 18, 19],
        [20, 21, 22, 23]]])>

,我想通过沿第3维索引将其减少为3x2。这是索引向量的样子:

y = tf.constant(np.array([[0, 1, 0, 0], [0, 0, 1, 0], [1, 0, 0, 0]]))
<tf.Tensor: id=2093, shape=(3, 4), dtype=int64, numpy=
array([[0, 1, 0, 0],
       [0, 0, 1, 0],
       [1, 0, 0, 0]])>

所需的输出是:

<tf.Tensor: id=2103, shape=(3, 2), dtype=int64, numpy=
array([[ 1,  5],
       [10, 14],
       [16, 20]])>

我尝试了tf.batch_gather(x, y),但它给出了不同的输出。我需要collect_nd还是可以通过batch_gather解决?

1 个答案:

答案 0 :(得分:2)

您需要tf.boolean_mask()

import tensorflow as tf
import numpy as np

x = tf.reshape(tf.range(24), (3,2,4))
y = tf.constant(np.array([[0, 1, 0, 0], [0, 0, 1, 0], [1, 0, 0, 0]]))

result = tf.boolean_mask(tf.transpose(x,[0,2,1]),y)

with tf.Session() as sess:
    print(sess.run(result))

[[ 1  5]
 [10 14]
 [16 20]]