如何通过索引为张量分配值?

时间:2020-09-01 14:03:11

标签: python tensorflow

我有一个2D张量,形状为(M,N)。我想得到一个掩码,对于每一行,给定张量的top-k为1,其他为0。 例如,张量为:

[[1,3,5,7],
 [2,4,7,0],
 [9,3,1,5]]

如果将topk设置为1,则掩码应为:

[[0,0,0,1],
 [0,0,1,0],
 [1,0,0,0]]

如果将topk设置为2,则掩码应为:

[[0,0,1,1],
 [0,1,1,0],
 [1,0,0,1]]

我想出一种非常繁琐的方法:

_, nn_idx = tf.nn.top_k(tmp, top_k)  # the shape of tmp is (M,N) and the shape of nn_idx is (M,top_k)
nn_idx_one = tf.reshape(nn_idx, [-1, 1])
nn_idx_multi_hot = tf.one_hot(nn_idx_one, depth=N)  # (M*top_k,1) -> (M*top_k,N), N is about 100000 (very big)
nn_idx_multi_hot = tf.reshape(nn_idx_multi_hot, [-1, top_k, N])  # (M,top_k,N)
nn_idx_multi_hot = tf.reduce_sum(nn_idx_multi_hot, axis=1)  # (M,N) 
mask_a = tf.ones(shape=[M, N])
mask_b = tf.zeros(shape=[M, N])
mask = tf.where(nn_idx_multi_hot > 0, mask_a , mask_b )  # the target mask

此操作占用大量内存,我想有一种简洁的方法,有人可以帮助我吗?

1 个答案:

答案 0 :(得分:0)

您可以使用tf.scatter_nd来做到这一点:

import tensorflow as tf

tf.random.set_seed(0)
tmp = tf.random.uniform((3, 4), 0, 10, dtype=tf.int32)
top_k = 2
tf.print(tmp)
# [[3 9 1 7]
#  [7 4 0 9]
#  [6 6 0 7]]
_, nn_idx = tf.nn.top_k(tmp, top_k)
s = tf.shape(tmp, out_type=nn_idx.dtype)
row_idx = tf.repeat(tf.range(s[0]), top_k)
ones_idx = tf.stack([row_idx, tf.reshape(nn_idx, [-1])], axis=1)
res = tf.scatter_nd(ones_idx, tf.ones(s[0] * top_k, tmp.dtype), s)
tf.print(res)
# [[0 1 0 1]
#  [1 0 0 1]
#  [1 0 0 1]]

编辑:对于具有图形模式且没有tf.repeat的1.x版本,以下内容是等效的:

import tensorflow as tf

with tf.Graph().as_default(), tf.Session() as sess:
    tf.set_random_seed(0)
    tmp = tf.random_uniform((3, 4), 0, 10, dtype=tf.int32)
    top_k = 2
    _, nn_idx = tf.nn.top_k(tmp, top_k)
    s = tf.shape(tmp, out_type=nn_idx.dtype)
    row_idx = tf.tile(tf.expand_dims(tf.range(s[0]), 1), (1, top_k))
    ones_idx = tf.stack([tf.reshape(row_idx, [-1]),
                         tf.reshape(nn_idx, [-1])], axis=1)
    res = tf.scatter_nd(ones_idx, tf.ones(s[0] * top_k, tmp.dtype), s)
    tmp_val, res_val = sess.run((tmp, res))
    print(tmp_val)
    # [[2 4 8 2]
    #  [4 8 7 5]
    #  [1 9 8 8]]
    print(res_val)
    # [[0 1 1 0]
    #  [0 1 1 0]
    #  [0 1 1 0]]