For any 2D tensor such as
[[2, 5, 4, 7], [7, 5, 6, 8]],
I want to apply a softmax over the top k elements of each row and then construct a new tensor by replacing all the other elements with 0.
The result should be the softmax over the top k elements (here k = 2) of each row, [[7, 5], [8, 7]], which gives [[0.880797, 0.11920291], [0.7310586, 0.26894143]], and then a new tensor reconstructed according to the indices of the top k elements in the original tensor. The final result should be
[[0, 0.11920291, 0, 0.880797], [0.26894143, 0, 0, 0.7310586]].
Is it possible to implement this kind of masked softmax in TensorFlow? Thanks a lot in advance!
Answer 0 (score: 4)
Here is what you can do:
import tensorflow as tf
# Input data
a = tf.placeholder(tf.float32, [None, None])
num_top = tf.placeholder(tf.int32, [])
# Find top elements
a_top, a_top_idx = tf.nn.top_k(a, num_top, sorted=False)
# Apply softmax
a_top_sm = tf.nn.softmax(a_top)
# Reconstruct into original shape
a_shape = tf.shape(a)
a_row_idx = tf.tile(tf.range(a_shape[0])[:, tf.newaxis], (1, num_top))
scatter_idx = tf.stack([a_row_idx, a_top_idx], axis=-1)
result = tf.scatter_nd(scatter_idx, a_top_sm, a_shape)
# Test
with tf.Session() as sess:
    result_val = sess.run(result, feed_dict={a: [[2, 5, 4, 7], [7, 5, 6, 8]], num_top: 2})
    print(result_val)
Output:
[[0. 0.11920291 0. 0.880797 ]
[0.26894143 0. 0. 0.7310586 ]]
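To see what the top_k / softmax / scatter_nd pipeline computes, here is a minimal NumPy sketch of the same top-k masked softmax (an illustrative reimplementation for checking the numbers, not the TensorFlow graph itself):

```python
import numpy as np

def top_k_softmax(a, k):
    """Softmax over the top-k entries of each row; zeros elsewhere."""
    a = np.asarray(a, dtype=np.float64)
    out = np.zeros_like(a)
    for i, row in enumerate(a):
        idx = np.argpartition(row, -k)[-k:]    # indices of the k largest entries
        e = np.exp(row[idx] - row[idx].max())  # shift by the max for numerical stability
        out[i, idx] = e / e.sum()              # scatter back into the original positions
    return out

print(top_k_softmax([[2, 5, 4, 7], [7, 5, 6, 8]], k=2))
```

This reproduces the values above: each row sums to 1 over its two kept entries, and everything else stays exactly 0.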
EDIT:
Actually, there is a function that does what you want more directly, tf.sparse.softmax. However, it requires a SparseTensor as input, and I am not sure it would be faster, since it has to figure out which sparse values go together in each softmax. The nice thing about this function is that you could have a different number of elements for the softmax in each row, although in your case that does not seem to matter. Anyway, here is an implementation in case you find it useful.
import tensorflow as tf
a = tf.placeholder(tf.float32, [None, None])
num_top = tf.placeholder(tf.int32, [])
# Find top elements
a_top, a_top_idx = tf.nn.top_k(a, num_top, sorted=False)
# Flatten values
sparse_values = tf.reshape(a_top, [-1])
# Make sparse indices
shape = tf.cast(tf.shape(a), tf.int64)
a_row_idx = tf.tile(tf.range(shape[0])[:, tf.newaxis], (1, num_top))
sparse_idx = tf.stack([a_row_idx, tf.cast(a_top_idx, tf.int64)], axis=-1)
sparse_idx = tf.reshape(sparse_idx, [-1, 2])
# Make sparse tensor
a_top_sparse = tf.SparseTensor(sparse_idx, sparse_values, shape)
# Reorder sparse tensor
a_top_sparse = tf.sparse.reorder(a_top_sparse)
# Softmax
result_sparse = tf.sparse.softmax(a_top_sparse)
# Convert back to dense (or you can keep working with the sparse tensor)
result = tf.sparse.to_dense(result_sparse)
# Test
with tf.Session() as sess:
    result_val = sess.run(result, feed_dict={a: [[2, 5, 4, 7], [7, 5, 6, 8]], num_top: 2})
    print(result_val)  # Same as before
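The index bookkeeping here (tiling the row numbers and pairing them with the top-k column indices) is the subtle part. A small NumPy illustration of what sparse_idx ends up holding for the example input (the ordering within each row may differ from TF's unsorted top_k, but the set of pairs is the same):

```python
import numpy as np

a = np.array([[2, 5, 4, 7], [7, 5, 6, 8]])
num_top = 2
# Column indices of the top-k entries per row (order within a row is unspecified)
col_idx = np.argpartition(a, -num_top, axis=1)[:, -num_top:]
# Tile the row numbers to the same shape: [[0, 0], [1, 1]]
row_idx = np.tile(np.arange(a.shape[0])[:, None], (1, num_top))
# Stack into (row, col) pairs and flatten to shape [-1, 2]
sparse_idx = np.stack([row_idx, col_idx], axis=-1).reshape(-1, 2)
print(sparse_idx)  # one (row, col) pair per kept entry
```

These flattened pairs are exactly what tf.SparseTensor needs, which is why the TF code reshapes to [-1, 2] before constructing it.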
Answer 1 (score: 0)
Assume your weight tensor w has shape (None, N).
Find the minimum of the top k elements of each row:
top_kw = tf.math.top_k(w, k=10, sorted=False)[0]
min_w = tf.reduce_min(top_kw, axis=1, keepdims=True)
Generate a boolean mask for the weight tensor:
mask_w = tf.greater_equal(w, min_w)
mask_w = tf.cast(mask_w, tf.float32)
Compute a custom softmax using the mask:
w = tf.multiply(tf.exp(w), mask_w) / tf.reduce_sum(tf.multiply(tf.exp(w), mask_w), axis=1, keepdims=True)
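The same mask-based computation can be sketched in NumPy (an illustration of the approach, not the TF code; note one caveat: the >= comparison can keep more than k entries per row if values tie with the k-th largest, unlike the scatter_nd approach above):

```python
import numpy as np

def masked_softmax(w, k):
    """Softmax restricted to entries >= the k-th largest value of each row."""
    w = np.asarray(w, dtype=np.float64)
    top_k = np.sort(w, axis=1)[:, -k:]            # k largest values per row
    min_w = top_k.min(axis=1, keepdims=True)      # per-row threshold: the k-th largest
    mask = (w >= min_w).astype(np.float64)        # 1.0 where the entry is kept
    e = np.exp(w) * mask                          # zero out everything below the threshold
    return e / e.sum(axis=1, keepdims=True)

print(masked_softmax([[2, 5, 4, 7], [7, 5, 6, 8]], k=2))
```

For the example input (no ties) this matches the scatter_nd result exactly. For better numerical stability with large weights, you would subtract the per-row maximum before exponentiating, as the exp of a large value can overflow.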