我有一个输入张量,表示项目和项目数量之间的交替:
[item0, qty0, item1, qty1, ...]
我想将这个张量展开为
[[item0*qty0], [item1*qty1], ...]
示例:
[1000, 2, 3000, 5, ...]
[[100,1000], [3000,3000,3000,3000,3000], ...]
这在tensorflow 1.x btw中。
答案 0 :(得分:1)
在tf1.x版本中。
import tensorflow as tf
inputs = tf.constant([10,2,20,3,30,4])
x_unpacked = tf.unstack(tf.reshape(inputs,(-1,2)))
tmp = []
for t in x_unpacked:
tmp.append(tf.tile([t[0]], [t[1]]))
ans = tf.concat(tmp, axis=0)
with tf.Session() as sess:
print(sess.run(ans))
# [10 10 20 20 20 30 30 30 30]
在tf2.x中,它可以是一行,
tf.concat([tf.tile([x[0]],[x[1]]) for x in tf.reshape(inputs, (-1,2))], axis=0)
# [10 10 20 20 20 30 30 30 30]
答案 1 :(得分:0)
我从zihaozhihao给出的答案开始,但是由于输入具有维(?),长度是未知的,因此tf.unstack无法使用。
但是,建议使用map_fn的错误消息之一似乎有效:
x_unpacked = tf.reshape(input, (-1, 2))
tiled = tf.map_fn(lambda x: tf.tile([x[0]], [x[1]]), x_unpacked)