tf批量收集具有不同索引的每个元素

时间:2019-05-29 08:17:07

标签: python tensorflow

如何通过不同的索引从小型批处理中gather使用不同的元素?
例如 以下(mini_batch=3)张量:

[[[ 0  1  2  3]
  [ 4  5  6  7]
  [ 8  9 10 11]
  [12 13 14 15]]

 [[16 17 18 19]
  [20 21 22 23]
  [24 25 26 27]
  [28 29 30 31]]

 [[32 33 34 35]
  [36 37 38 39]
  [40 41 42 43]
  [44 45 46 47]]]

带有索引

[[[0 0]
  [2 0]
  [1 0]
  [0 2]
  [2 2]
  [1 2]
  [0 1]
  [2 1]
  [1 1]]

 [[0 0]
  [3 0]
  [1 0]
  [0 3]
  [2 3]
  [1 3]
  [0 1]
  [3 1]
  [1 1]]

 [[0 0]
  [3 0]
  [1 0]
  [0 3]
  [2 3]
  [1 3]
  [0 1]
  [3 1]
  [1 1]]]

我希望结果是

[[[ 0  1    3]
  [ 4  5    7]
  [12 13   15]]

 [[16 17  19]
  [20 21  23]
  [28 29  31]]

 [[32 33 34 ]
  [36 37 38 ]
  [40 41 42 ]]

但是使用以下代码:

tf.gather_nd(batch_input,batch_indecies)

我得到以下输出

indices[1,8] = [3, 3] does not index into param shape [3,4,4] [Op:GatherNd]

根据我的理解,对于索引[[[0,0]]],tf将获得第一批第一行,即[0,1,2,3],而我却想从第一批中获取[0]

这是我设法做到的方式(我仍然切片并堆叠,并将尾随的批处理索引连接到每个索引,因此我认为有更好的解决方案了):

def gather_by_batch(tensor, batch_indecies):  # e.g shapes (4,10,10) (4,3,3) 
    #working
    per_dim_shape = tf.shape(batch_indecies)[1]
    stacked_indecies = tf.split(batch_indecies, FLAGS.batch_size)

    final_indecies =[]
    for idx,indecies in enumerate(stacked_indecies):
        final_indecies.append(tf.concat([tf.broadcast_to([idx], [per_dim_shape*per_dim_shape, 1]), indecies], axis=1))

    return tf.reshape(tf.gather_nd(tensor, final_indecies), (-1, per_dim_shape, per_dim_shape))

0 个答案:

没有答案