BatchToSpaceND实际上如何工作?

时间:2019-06-02 09:36:31

标签: python python-3.x tensorflow

我试图弄清楚BatchToSpaceND是如何排列输入矩阵的。示例之一如下:

  

(3)对于形状为[4,2,2,1]和block_size为2的以下输入:

x = [[[[1], [3]], [[9], [11]]],
     [[[2], [4]], [[10], [12]]],
     [[[5], [7]], [[13], [15]]],
     [[[6], [8]], [[14], [16]]]]
     

输出张量的形状为[1、4、4、1]和值:

x = [[[1],   [2],  [3],  [4]],
     [[5],   [6],  [7],  [8]],
     [[9],  [10], [11],  [12]],
     [[13], [14], [15],  [16]]]

有人知道输出张量是如何导出的吗?为什么第一行是[[1], [2], [3], [4]]而不是[[1], [3], [9], [11]]?我也尝试了一些代码:

import tensorflow as tf
sess = tf.InteractiveSession()

a = [[[[1], [3]], [[9], [11]]],
     [[[2], [4]], [[10], [12]]],
     [[[5], [7]], [[13], [15]]],
     [[[6], [8]], [[14], [16]]]]
b = [2, 2, 1, 2, 2, 1]
a = tf.reshape(a, b)

b = [1, 2, 2, 2, 2, 1]
a = tf.reshape(a, b)

b = [1, 4, 4, 1]
a = tf.reshape(a, b)

print(a.eval())

[[[[ 1]
   [ 3]
   [ 9]
   [11]]

  [[ 2]
   [ 4]
   [10]
   [12]]

  [[ 5]
   [ 7]
   [13]
   [15]]

  [[ 6]
   [ 8]
   [14]
   [16]]]]

这不是文档中的结果。

1 个答案:

答案 0 :(得分:2)

让我们考虑the documentation的参数部分:

  

input:一个Tensor。形状为input_shape = [batch] + spatial_shape + remaining_shape的N-D,其中spatial_shape的尺寸为M

因此对于特定示例,这意味着我们具有批处理尺寸4,空间形状(2, 2)和其余形状(1,)。在这里考虑一个真实的例子是有启发性的。让我们将此输入张量视为一批具有1个通道(例如灰度)的四个2x2图像。由于该操作不会修改remaining_shape,因此我们可以忽略它以进行进一步的探索。也就是说,输入实际上包含以下2x2“图像”:

 1   3
 9  11
--------
 2   4
10  12
--------
 5   7
13  15
--------
 6   8
14  16

现在操作需要将批处理维度重塑为空间维度,类似于将大小为a的一维数组batch重塑为a.reshape(-1, *block_shape)。如果我们考虑批次索引[0, 1, 2, 3],它们将被重塑为[[0, 1], [2, 3]](省略新的大小为1的批次尺寸)。实际上,这意味着我们应该拍摄四张2x2图像并将它们并排放置,其中block_shape表示布局,以便创建一张4x4图像。但是,到目前为止,还没有完成,还需要完成另一个步骤,即空间维度是交错的,如文档所示:

  

此操作将“批”尺寸0重塑为形状为M + 1的{​​{1}}尺寸,将这些块交织回由空间尺寸block_shape + [batch]定义的网格中,以获得与输入具有相同等级的结果。

即在我们拥有的网格中布置图像:

[1, ..., M]

现在,我们需要对单个图像的行和列尺寸进行交织以便得出最终结果:

 1   3     2   4
 9  11    10  12

 5   7     6   8
13  15    14  16

哪个给:

        -------⅂
       |       |
    -------⅂   |
   |   |   |   |
   v   v   |   |

 1   3     2   4
                  <---⅂
 9  11    10  12      |
                  <---|---⅂
                      |   |
                      |   |
 5   7     6   8   ---⅃   |
                          |
13  15    14  16   -------⅃

该示例的实际输出具有 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 形状,因为它包含附加的(1, 4, 4, 1)(为方便起见,我们省略了它)并保留了批处理尺寸(1 in这种情况)。

等效代码示例

remaining_shape