我有一个3x2x4
张量:
x = tf.reshape(tf.range(24), (3,2,4))
<tf.Tensor: id=1928, shape=(3, 2, 4), dtype=int64, numpy=
array([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7]],
[[ 8, 9, 10, 11],
[12, 13, 14, 15]],
[[16, 17, 18, 19],
[20, 21, 22, 23]]])>
,我想通过沿第3维索引将其减少为3x2
。这是索引向量的样子:
y = tf.constant(np.array([[0, 1, 0, 0], [0, 0, 1, 0], [1, 0, 0, 0]]))
<tf.Tensor: id=2093, shape=(3, 4), dtype=int64, numpy=
array([[0, 1, 0, 0],
[0, 0, 1, 0],
[1, 0, 0, 0]])>
所需的输出是:
<tf.Tensor: id=2103, shape=(3, 2), dtype=int64, numpy=
array([[ 1, 5],
[10, 14],
[16, 20]])>
我尝试了tf.batch_gather(x, y)
,但它给出了不同的输出。我需要collect_nd还是可以通过batch_gather解决?
答案 0 :(得分:2)
您需要tf.boolean_mask()
。
import tensorflow as tf
import numpy as np
x = tf.reshape(tf.range(24), (3,2,4))
y = tf.constant(np.array([[0, 1, 0, 0], [0, 0, 1, 0], [1, 0, 0, 0]]))
result = tf.boolean_mask(tf.transpose(x,[0,2,1]),y)
with tf.Session() as sess:
print(sess.run(result))
[[ 1 5]
[10 14]
[16 20]]