我希望将特定数字重复不同的次数,如下所示:
x = np.array([0,1,2])
np.repeat(x,[3,4,5])
>>> array([0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2])
(0重复3次,1、4次,等等。)
这个答案(https://stackoverflow.com/a/35367161/2530674)似乎暗示我可以结合使用tf.tile
和tf.reshape
来达到相同的效果。但是,我相信只有重复次数是固定的情况。
如何在Tensorflow中获得相同的效果?
edit1:很遗憾,没有tf.repeat
。
答案 0 :(得分:3)
这是解决问题的一种“强力方法”,只需将每个值重复最多重复的次数即可,然后选择正确的元素:
import tensorflow as tf
# Repeats across the first dimension
def tf_repeat(arr, repeats):
arr = tf.expand_dims(arr, 1)
max_repeats = tf.reduce_max(repeats)
tile_repeats = tf.concat(
[[1], [max_repeats], tf.ones([tf.rank(arr) - 2], dtype=tf.int32)], axis=0)
arr_tiled = tf.tile(arr, tile_repeats)
mask = tf.less(tf.range(max_repeats), tf.expand_dims(repeats, 1))
result = tf.boolean_mask(arr_tiled, mask)
return result
with tf.Graph().as_default(), tf.Session() as sess:
print(sess.run(tf_repeat([0, 1, 2], [3, 4, 5])))
输出:
[0 0 0 1 1 1 1 2 2 2 2 2]