具有选定索引的tensorflow掩码框

时间:2018-05-06 00:44:05

标签: numpy tensorflow

假设我有一个Rank-2张量 A [[1,1,1,1], [2,2,2,2],[3,3,3,3], [4, 4, 4, 4], ...],我有一个选定的索引 B (来自tf.equal()或其他地方)例如[1, 3, 4]。我希望 B 中的任何 i A [i] 为零,以便 A 最终变为类似[1,1,1,1], [0,0,0,0],[3,3,3,3], [0,0,0,0], ...]。怎么做或可能吗?

1 个答案:

答案 0 :(得分:1)

有多种方法可以做到这一点。这是一个tf.one_hot()(测试代码):

import tensorflow as tf

a = tf.constant( [[1,1,1,1], [2,2,2,2],[3,3,3,3], [4, 4, 4, 4]] )
b = tf.constant( [ 1, 3, 4 ] )

one_hot = tf.one_hot( b, a.get_shape()[ 0 ].value, dtype = a.dtype )
mask = 1 - tf.reduce_sum( one_hot, axis = 0 )
res = a * mask[ ..., None ]

with tf.Session() as sess:
    print( sess.run( res ) )

或带有tf.scatter_nd()的这个(测试代码):

import tensorflow as tf

a = tf.constant( [[1,1,1,1], [2,2,2,2], [3,3,3,3], [4, 4, 4, 4]] )
b = tf.constant( [ 1, 3 ] )

mask = 1 - tf.scatter_nd( b[ ..., None ], tf.ones_like( b ), shape = [ a.get_shape()[ 0 ].value ] )
res = a * mask[ ..., None ]

with tf.Session() as sess:
    print( sess.run( res ) )

将输出:

  

[[1 1 1 1]
   [0 0 0 0]
   [3 3 3 3]
   [0 0 0 0]]

根据需要。