我尝试对选定的索引执行softmax,使用无穷大掩码来清除不需要的索引。但是,与nan
相比,这些不需要的内容的渐变变为0
。
我没有使用布尔掩码的原因是我的批处理中的掩码索引不同,这不能以一个漂亮的矩阵形式结束。如果有解决方法,我将非常乐意采纳。
我测试了无限模板的代码是
import numpy as np
import tensorflow as tf
a = tf.placeholder(tf.float32, [5])
inf_mask = tf.placeholder(tf.float32, [5])
b = tf.multiply(a, inf_mask)
sf = tf.nn.softmax(b)
loss = (sf[2] - 0)
grad = tf.gradients(loss, a)
sess = tf.Session()
a_np = np.ones([5])
np_mask = np.ones([5]) * 4
np_mask[1] = -np.inf
print sess.run([sf, grad], feed_dict={
a: a_np,
inf_mask: np_mask
})
sess.close()
输出
[array([ 0.25, 0. , 0.25, 0.25, 0.25], dtype=float32), [array([-0.25, nan, 0.75, -0.25, -0.25], dtype=float32)]]
蒙版正在运作,但渐变有一个nan
,我认为它应该是0
。