如何在张量流中重复张量元素可变次数

时间:2019-10-09 05:51:57

标签: tensorflow

我有一个输入张量,表示项目和项目数量之间的交替:

[item0, qty0, item1, qty1, ...]

我想将这个张量展开为

[[item0*qty0], [item1*qty1], ...]

示例:

[1000, 2, 3000, 5, ...]

[[100,1000], [3000,3000,3000,3000,3000], ...]

这在tensorflow 1.x btw中。

2 个答案:

答案 0 :(得分:1)

在tf1.x版本中。

import tensorflow as tf
inputs = tf.constant([10,2,20,3,30,4])
x_unpacked = tf.unstack(tf.reshape(inputs,(-1,2)))
tmp = []
for t in x_unpacked:
    tmp.append(tf.tile([t[0]], [t[1]]))
ans = tf.concat(tmp, axis=0)

with tf.Session() as sess:
    print(sess.run(ans))
# [10 10 20 20 20 30 30 30 30]

在tf2.x中,它可以是一行,

tf.concat([tf.tile([x[0]],[x[1]]) for x in tf.reshape(inputs, (-1,2))], axis=0)
# [10 10 20 20 20 30 30 30 30]

答案 1 :(得分:0)

我从zihaozhihao给出的答案开始,但是由于输入具有维(?),长度是未知的,因此tf.unstack无法使用。

但是,建议使用map_fn的错误消息之一似乎有效:

x_unpacked = tf.reshape(input, (-1, 2))
tiled = tf.map_fn(lambda x: tf.tile([x[0]], [x[1]]), x_unpacked)