Tensorflow在4d张量中收集特定元素

时间:2017-10-16 16:31:35

标签: python tensorflow

我有一个尺寸为[B,Y,X,N]的4D张量参数,并希望从中选择一个特定的切片n ∈ N,这样我得到的张量就是尺寸[B,Y,X,1](或[B,Y,X])。

特定切片应该是平均包含最高数字的切片;我得到如此的指数:
indices = tf.argmax(tf.reduce_sum(Params, [1,2]), 1)(形状[B])

我使用gathergather_nd尝试了不同的解决方案,但无法使其发挥作用。有很多帖子与此非常相似,但我无法应用其中的一种解决方案。

我正在运行Tensorflow 1.3,因此可以使用gather的新奇轴参数。

1 个答案:

答案 0 :(得分:0)

在下面的示例代码中,输入的形状为[2,3,4,5],结果形状为[2,3,4]

主要观点是:

  • 使用gather_nd很容易获得一行而不是一列,所以我用tf.transpose切换了最后两个维度。
  • 我们需要将tf.argmax(下面indices)的索引转换为final_idx中真正可用的索引(请参阅下面的tf.gather_nd)。转换是通过堆叠三个组件完成的:
    • [0 0 0 1 1 1]
    • [0 1 2 0 1 2]
    • [3 3 3 0 0 0]

所以我们可以从[3, 0]转到

 [[[0 0 3]
   [0 1 3]
   [0 2 3]]
  [[1 0 0]
   [1 1 0]
   [1 2 0]]].
Batch,Y,X = 2, 3, 4
tf.reset_default_graph()

data = np.arange(Batch*Y*X*5)
np.random.shuffle(data)

Params = tf.constant(np.reshape(data, [Batch, Y, X, 5]), dtype=tf.int32)
indices = tf.argmax(tf.reduce_sum(Params, [1,2]), 1)
indices = tf.cast(tf.reshape(tf.tile(tf.reshape(indices, [-1,1]), 
                                     [1,Y]), [-1]), tf.int32)

idx = tf.reshape(tf.range(batch_size), [-1,1])
idx = tf.reshape(tf.tile(idx, [1, y]), [-1])

inc = tf.reshape(tf.tile(tf.range(Y), [Batch]), [-1])
final_idx = tf.reshape(tf.stack([idx, inc, indices], 1), [Batch, Y, -1])
transposed = tf.transpose(Params, [0, 1, 3, 2])
slice = tf.gather_nd(transposed, final_idx)

with tf.Session() as sess:
    print sess.run(Params)
    print sess.run(idx)    
    print sess.run(inc)
    print sess.run(indices)
    print sess.run(final_idx)
    print sess.run(slice)
[[[[ 22  38  68  49 119]
   [ 47  74 111 117  90]
   [ 14  32  31  12  75]
   [ 93  34  57   3  56]]

  [[ 69  21   4  94  39]
   [ 83  96  62 102  80]
   [ 55 113  48  98  29]
   [107  81  67  76  28]]

  [[ 53  51  77  66  63]
   [ 92 115 118 116  13]
   [ 43  78  15   1   0]
   [ 99  50  27  60  73]]]


 [[[ 97  88  91  64  86]
   [ 72 110  26  87  33]
   [ 70  30  41 114   5]
   [ 95  82  46  16  61]]

  [[109  71  45   8  40]
   [101   9  23  59  10]
   [ 37  65  44  11  19]
   [ 42 104 106 105  18]]

  [[112  58   7  17  89]
   [ 25  79 103  85  20]
   [ 35   6 108 100  36]
   [ 24  52   2  54  84]]]]

[0 0 0 1 1 1]
[0 1 2 0 1 2]
[3 3 3 0 0 0]

[[[0 0 3]
  [0 1 3]
  [0 2 3]]

 [[1 0 0]
  [1 1 0]
  [1 2 0]]]

[[[ 49 117  12   3]
  [ 94 102  98  76]
  [ 66 116   1  60]]

 [[ 97  72  70  95]
  [109 101  37  42]
  [112  25  35  24]]]