带索引的遮罩张量

时间:2020-03-25 15:37:18

标签: python tensorflow mask

给定大小为x的2D张量(batch_size, D)以及大小分别为m1的掩蔽索引向量m2(batch_size, ),如何掩盖张量从而将i的行x的索引中小于m1[i]或大于m2[i]的索引的值设置为0?请注意,m1m2引用索引(不是值),保证m1[i]小于或等于m2[i],并且不同的行可能具有不同数量的值没有被掩盖的。

例如,x的大小为(2, 4)m1m2的大小为(2, )

# input
x = tf.Variable([[1., 2., -1., 5.], [4., -3., 3., -2]])
m1 = tf.Variable([1, 2])
m2 = tf.Variable([2, 2])

# desired masked result
y = tf.Variable([[0., 2., -1., 0.], [0., 0., 3., 0.]])

这里,第一行有两个非掩码值,而第二行只有一个。我正在使用TensorFlow 2。

0 个答案:

没有答案