例如,有一个张量
a=[[1,2,3,4,5],
[2,3,4,5,6]]
indices =[[1, 0, 1, 0, 0],
[0, 1, 0, 0, 0]]
我只想对索引值为1(来自b)的元素(来自a)使用激活。例如,我只想对带有索引[0,0],[0,2],[1,1]的元素使用激活函数。
谢谢!
答案 0 :(得分:1)
您可以使用tf.where:
tf.where(tf.cast(indices, dtype=tf.bool), tf.nn.sigmoid(a), a)
对于你的例子:
import tensorflow as tf
a = tf.constant([[1,2,3,4,5], [2,3,4,5,6]], dtype=tf.float32)
indices = tf.constant([[1, 0, 1, 0, 0], [0, 1, 0, 0, 0]],
dtype = tf.int32)
result = tf.where(tf.cast(indices, dtype=tf.bool), tf.nn.sigmoid(a), a)
with tf.Session() as sess:
print(sess.run(result))
打印:
[[ 0.7310586 2. 0.95257413 4. 5. ]
[ 2. 0.95257413 4. 5. 6 ]]