我有两个问题:
让张量T
具有形状[n1, n2, n3, n4]
。让另一个类型为IDX
的形状[n1, n2]
的张量int
包含所需的索引。我如何获得形状张量[n1, n2, n4]
,我只想提取n3
dim T
dim中的IDX
指数x = [[[2, 3, 1, 2, 5],
[7, 1, 5, 6, 0],
[7, 8, 1, 3, 8]],
[[0, 7, 7, 6, 9],
[5, 6, 7, 8, 8],
[2, 3, 2, 9, 6]]]
idx = [[1, 0, 2],
[4, 3, 3]]
res = [[3, 7, 1],
[9, 8, 9]]`
。简单的例子:
:nth-of-class
提前感谢您的帮助!
答案 0 :(得分:1)
我使用tf.gather_nd
处理问题1。
输入是:
x
:你的张量T
从中提取形状[n1, n2, n3, n4]
的值
size(T)
idx
:您要从T
中提取的形状[n1, n2]
且包含0
到n3 - 1
结果是:
res
:T
中idx
的{{1}}的提取值,[n1, n2, n4]
由于tf.gather_nd()
期望您创建要在x
中检索的整个索引(例如[1, 0, 4, 1]
),我们必须先在indices_base
中创建它。
论证indices
需要具有res + R
形状,即[n1, n2, n4, R]
,其中R=4
是张量x
的等级。
# Inputs:
n1 = 2
n2 = 3
n3 = 5
n4 = 2
x = tf.reshape(tf.range(n1*n2*n3*n4), [n1, n2, n3, n4]) # range(60) reshaped
idx = tf.constant([[1, 0, 2], [4, 3, 3]]) # shape [n1, n2]
range_n1 = tf.reshape(tf.range(n1), [n1, 1, 1, 1])
indices_base_1 = tf.tile(range_n1, [1, n2, n4, 1])
range_n2 = tf.reshape(tf.range(n2), [1, n2, 1, 1])
indices_base_2 = tf.tile(range_n2, [n1, 1, n4, 1])
range_n4 = tf.reshape(tf.range(n4), [1, 1, n4, 1])
indices_base_4 = tf.tile(range_n4, [n1, n2, 1, 1])
idx = tf.reshape(idx, [n1, n2, 1, 1])
idx = tf.tile(idx, [1, 1, n4, 1])
# Create the big indices needed of shape [n1, n2, n3, n4]
indices = tf.concat(3, [indices_base_1, indices_base_2, idx, indices_base_4])
# Finally we can apply tf.gather_nd
res = tf.gather_nd(x, indices)
无论如何,这是非常复杂的,我不确定它是否能产生良好的性能。
P.S:你应该在一个单独的帖子中发布问题2。