具有非均匀重复的np.repeat()的Tensorflow等价物?

时间:2017-03-22 15:38:45

标签: tensorflow

说我有

a = tf.range(5)
b = tf.convert_to_tensor([3,2,0,1,4])

我有什么方法可以获得

ans = array([0, 0, 0, 1, 1, 3, 4, 4, 4, 4])

如果我做np.repeat(np.arange(5), [3,2,0,1,4])的方式我会这样做?

我已尝试tf.tile(a,b)并重新整形,就像TensorFlow: numpy.repeat() alternative中提到的那样,但这给了我一个形状等级的值误差。

非常感谢!

1 个答案:

答案 0 :(得分:1)

这可以使用Tensorflow中提供的其他操作的组合来完成。 这是在Tensorflow 0.12.0

上测试的
import tensorflow as tf

def repeat(x):
    # get maximum repeat length in x
    maxlen = tf.reduce_max(x)

    # get the length of x
    xlen = tf.shape(x)[0]

    # create a range with the length of x
    rng = tf.range(xlen)

    # tile it to the maximum repeat length, it should be of shape [xlen, maxlen] now
    rng_tiled = tf.tile(tf.expand_dims(rng, 1), tf.pack([1, maxlen]))

    # create a sequence mask using x
    # this will create a boolean matrix of shape [xlen, maxlen]
    # where result[i,j] is true if j < x[i].
    mask = tf.sequence_mask(x, maxlen)

    # mask the elements based on the sequence mask
    return tf.boolean_mask(rng_tiled, mask)


x = tf.placeholder(tf.int32, [None], 'x')
y = repeat(x)

sess = tf.Session()
sess.run(y, {x: [3,2,0,1,4]})

这应输出:array([0, 0, 0, 1, 1, 3, 4, 4, 4, 4])