我有一个2D张量,形状为(M,N)。我想得到一个掩码,对于每一行,给定张量的top-k为1,其他为0。 例如,张量为:
[[1,3,5,7],
[2,4,7,0],
[9,3,1,5]]
如果将topk设置为1,则掩码应为:
[[0,0,0,1],
[0,0,1,0],
[1,0,0,0]]
如果将topk设置为2,则掩码应为:
[[0,0,1,1],
[0,1,1,0],
[1,0,0,1]]
我想出一种非常繁琐的方法:
_, nn_idx = tf.nn.top_k(tmp, top_k) # the shape of tmp is (M,N) and the shape of nn_idx is (M,top_k)
nn_idx_one = tf.reshape(nn_idx, [-1, 1])
nn_idx_multi_hot = tf.one_hot(nn_idx_one, depth=N) # (M*top_k,1) -> (M*top_k,N), N is about 100000 (very big)
nn_idx_multi_hot = tf.reshape(nn_idx_multi_hot, [-1, top_k, N]) # (M,top_k,N)
nn_idx_multi_hot = tf.reduce_sum(nn_idx_multi_hot, axis=1) # (M,N)
mask_a = tf.ones(shape=[M, N])
mask_b = tf.zeros(shape=[M, N])
mask = tf.where(nn_idx_multi_hot > 0, mask_a , mask_b ) # the target mask
此操作占用大量内存,我想有一种简洁的方法,有人可以帮助我吗?
答案 0 :(得分:0)
您可以使用tf.scatter_nd
来做到这一点:
import tensorflow as tf
tf.random.set_seed(0)
tmp = tf.random.uniform((3, 4), 0, 10, dtype=tf.int32)
top_k = 2
tf.print(tmp)
# [[3 9 1 7]
# [7 4 0 9]
# [6 6 0 7]]
_, nn_idx = tf.nn.top_k(tmp, top_k)
s = tf.shape(tmp, out_type=nn_idx.dtype)
row_idx = tf.repeat(tf.range(s[0]), top_k)
ones_idx = tf.stack([row_idx, tf.reshape(nn_idx, [-1])], axis=1)
res = tf.scatter_nd(ones_idx, tf.ones(s[0] * top_k, tmp.dtype), s)
tf.print(res)
# [[0 1 0 1]
# [1 0 0 1]
# [1 0 0 1]]
编辑:对于具有图形模式且没有tf.repeat
的1.x版本,以下内容是等效的:
import tensorflow as tf
with tf.Graph().as_default(), tf.Session() as sess:
tf.set_random_seed(0)
tmp = tf.random_uniform((3, 4), 0, 10, dtype=tf.int32)
top_k = 2
_, nn_idx = tf.nn.top_k(tmp, top_k)
s = tf.shape(tmp, out_type=nn_idx.dtype)
row_idx = tf.tile(tf.expand_dims(tf.range(s[0]), 1), (1, top_k))
ones_idx = tf.stack([tf.reshape(row_idx, [-1]),
tf.reshape(nn_idx, [-1])], axis=1)
res = tf.scatter_nd(ones_idx, tf.ones(s[0] * top_k, tmp.dtype), s)
tmp_val, res_val = sess.run((tmp, res))
print(tmp_val)
# [[2 4 8 2]
# [4 8 7 5]
# [1 9 8 8]]
print(res_val)
# [[0 1 1 0]
# [0 1 1 0]
# [0 1 1 0]]