Tensorflow:索引的“交错”序列掩码?

时间:2019-05-08 19:32:29

标签: python tensorflow

输入类似:

[1, 3, 2]

所需的输出(在适当的张量中):

[1 0 0 
 0 1 0
 0 1 0
 0 1 0
 0 0 1
 0 0 1]

即,与tf.sequence_mask非常相似(它会给出类似的内容:

[1 1 1
 0 1 1
 0 1 0]

),但是在先前的序列掩码完成后,每个后续元素都“交错”开始。

非常感谢您的帮助。

1 个答案:

答案 0 :(得分:3)

这可以通过以下方式完成:采用大小等于输入元素数量的正方形恒等矩阵,然后对tf.tile()中的每一行应用inputs[i] i次身份矩阵:

import tensorflow as tf

inputs = tf.constant([1, 3, 2])

unit = tf.eye(num_rows=inputs.get_shape().as_list()[0])
unstacked = tf.unstack(unit)
tiled = [tf.tile(u[None, ...], multiples=[inputs[i], 1])
         for i, u in enumerate(unstacked)]
res = tf.concat(tiled, axis=0)

with tf.Session() as sess:
    print(sess.run(res))
# [[1. 0. 0.]
#  [0. 1. 0.]
#  [0. 1. 0.]
#  [0. 1. 0.]
#  [0. 0. 1.]
#  [0. 0. 1.]]