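How can I gather different elements from a mini-batch using a different set of indices for each batch element?
For example, given the following tensor (mini_batch=3):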
[[[ 0 1 2 3]
[ 4 5 6 7]
[ 8 9 10 11]
[12 13 14 15]]
[[16 17 18 19]
[20 21 22 23]
[24 25 26 27]
[28 29 30 31]]
[[32 33 34 35]
[36 37 38 39]
[40 41 42 43]
[44 45 46 47]]]
and the indices
[[[0 0]
[2 0]
[1 0]
[0 2]
[2 2]
[1 2]
[0 1]
[2 1]
[1 1]]
[[0 0]
[3 0]
[1 0]
[0 3]
[2 3]
[1 3]
[0 1]
[3 1]
[1 1]]
[[0 0]
[3 0]
[1 0]
[0 3]
[2 3]
[1 3]
[0 1]
[3 1]
[1 1]]]
I expect the result to be
[[[ 0 1 3]
[ 4 5 7]
[12 13 15]]
[[16 17 19]
[20 21 23]
[28 29 31]]
[[32 33 34]
[36 37 38]
[40 41 42]]]
However, with the following code:
tf.gather_nd(batch_input,batch_indecies)
I get the following error:
indices[1,8] = [3, 3] does not index into param shape [3,4,4] [Op:GatherNd]
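As I understand it, for the index [[[0,0]]], tf reads the pair as (batch, row) and returns the first row of the first batch element, i.e. [0,1,2,3], whereas I want it to read the pair as (row, column) within the first batch element and return [0].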
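To make that reading concrete, here is a minimal sketch, assuming TensorFlow 2.x with eager execution and a tensor built with `tf.range` to match the values above. With a (3, 4, 4) tensor, a 2-element index is consumed as (batch, row), while a 3-element index is consumed as (batch, row, col). This would also explain the error: read as (batch, row), the pair [3, 3] asks for batch element 3, which does not exist in a batch of size 3.

```python
import tensorflow as tf

# A (3, 4, 4) batch holding the values 0..47, as in the example above.
batch_input = tf.reshape(tf.range(48), (3, 4, 4))

# A 2-element index addresses (batch, row), so the result is a whole row.
row = tf.gather_nd(batch_input, [[0, 0]])

# A 3-element index addresses (batch, row, col), so the result is one value.
value = tf.gather_nd(batch_input, [[0, 0, 0]])

print(row.numpy().tolist())    # [[0, 1, 2, 3]]
print(value.numpy().tolist())  # [0]
```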
This is how I managed to do it (I still slice and stack, and prepend the batch index to each index, so I suspect there is a better solution):
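import tensorflow as tf

def gather_by_batch(tensor, batch_indecies):  # e.g. shapes (4, 10, 10) and (4, 9, 2) -> (4, 3, 3)
    # Assumes a static shape with k*k (row, col) pairs per batch element.
    num_indices = batch_indecies.shape[1]
    per_dim_shape = int(round(num_indices ** 0.5))  # k, the side of each output square
    stacked_indecies = tf.unstack(batch_indecies)  # one (k*k, 2) tensor per batch element
    final_indecies = []
    for idx, indecies in enumerate(stacked_indecies):
        # Prepend the batch index so every entry becomes (batch, row, col).
        batch_column = tf.fill([num_indices, 1], tf.constant(idx, dtype=indecies.dtype))
        final_indecies.append(tf.concat([batch_column, indecies], axis=1))
    return tf.reshape(tf.gather_nd(tensor, final_indecies), (-1, per_dim_shape, per_dim_shape))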
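
One possible alternative to the manual slicing is the `batch_dims` argument of `tf.gather_nd`, which pairs each batch element of the indices with the matching batch element of the tensor, so no batch index has to be prepended. The sketch below assumes a TensorFlow version whose `tf.gather_nd` accepts `batch_dims`. The names `pairs` and `batch_indices` are hypothetical, and the index values are illustrative only: they select rows and columns {0, 1, 3} in row-major order for every batch element and are not the exact indices listed in the question.

```python
import tensorflow as tf

# A (3, 4, 4) batch holding the values 0..47, as in the example above.
batch_input = tf.reshape(tf.range(48), (3, 4, 4))

# Hypothetical indices: the same nine (row, col) pairs for every batch element,
# covering rows {0, 1, 3} and columns {0, 1, 3} in row-major order.
pairs = [[r, c] for r in (0, 1, 3) for c in (0, 1, 3)]
batch_indices = tf.constant([pairs] * 3)  # shape (3, 9, 2)

# batch_dims=1 matches batch_indices[b] with batch_input[b].
gathered = tf.gather_nd(batch_input, batch_indices, batch_dims=1)  # shape (3, 9)
result = tf.reshape(gathered, (-1, 3, 3))

print(result[0].numpy().tolist())  # [[0, 1, 3], [4, 5, 7], [12, 13, 15]]
print(result[1].numpy().tolist())  # [[16, 17, 19], [20, 21, 23], [28, 29, 31]]
```

The order of the pairs determines the order of the gathered values, so the final reshape only yields the expected 3x3 blocks when the pairs are listed row by row.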