我有一个3d张量,我需要在第二维中的某些位置保留向量,并将剩余的向量归零。位置指定为1d数组。我认为最好的方法是将张量乘以二进制掩码。
这是一个简单的Numpy版本:
A.shape: (b, n, m)
indices.shape: (b)
mask = np.zeros(A.shape)
for i in range(b):
mask[i][indices[i]] = 1
result = A*mask
因此对于A中的每个nxm矩阵,我需要保留由indices指定的行,并将其余部分归零。
我尝试使用tf.scatter_nd操作在TensorFlow中执行此操作,但我无法找出正确的索引形状:
shape = tf.constant([3,5,4])
A = tf.random_normal(shape)
indices = tf.constant([2,1,4]) #???
updates = tf.ones((3,4))
mask = tf.scatter_nd(indices, updates, shape)
result = A*mask