以二维张量y [i]复制行,其中i是另一个张量y的索引?

时间:2019-05-07 15:04:31

标签: tensorflow

我正在寻找一个tf操作,该操作将输入张量x中的元素复制y [i]倍,其中i是第二张量中的索引。更准确地说,该操作应达到以下目的:

x = tf.constant([[1, 4], [2, 5], [3, 6]])
y = tf.constant([3, 2, 4])

z = <operation>(x, y) # [[1, 4], [1, 4], [1, 4],
                         [2, 5], [2, 5], 
                         [3, 6], [3, 6], [3, 6], [3, 6]]

我可以使用什么操作?谢谢:)

1 个答案:

答案 0 :(得分:1)

关键思想是构建根据y复制的索引的一维张量,然后执行tf.gather

def repeat(t, times):
    num_elements = tf.shape(t)[0]

    def cond_fn(i, _):
        return i < num_elements

    def body_fn(i, indices_ta):
        repeated_i = tf.tile(i[tf.newaxis], times[i, tf.newaxis])
        return (i + 1, indices_ta.write(i, repeated_i))

    indices_ta = tf.TensorArray(times.dtype, num_elements, infer_shape=False)
    _, indices_ta = tf.while_loop(
        cond_fn,
        body_fn,
        loop_vars=(0, indices_ta))

    return tf.gather(t, indices_ta.concat())