输入类似:
[1, 3, 2]
所需的输出(在适当的张量中):
[1 0 0
0 1 0
0 1 0
0 1 0
0 0 1
0 0 1]
即,与tf.sequence_mask非常相似(它会给出类似的内容:
[1 1 1
0 1 1
0 1 0]
),但是在先前的序列掩码完成后,每个后续元素都“交错”开始。
非常感谢您的帮助。
答案 0 :(得分:3)
这可以通过以下方式完成:采用大小等于输入元素数量的正方形恒等矩阵,然后对tf.tile()
中的每一行应用inputs[i]
i
次身份矩阵:
import tensorflow as tf
inputs = tf.constant([1, 3, 2])
unit = tf.eye(num_rows=inputs.get_shape().as_list()[0])
unstacked = tf.unstack(unit)
tiled = [tf.tile(u[None, ...], multiples=[inputs[i], 1])
for i, u in enumerate(unstacked)]
res = tf.concat(tiled, axis=0)
with tf.Session() as sess:
print(sess.run(res))
# [[1. 0. 0.]
# [0. 1. 0.]
# [0. 1. 0.]
# [0. 1. 0.]
# [0. 0. 1.]
# [0. 0. 1.]]