How can I implement tf.space_to_depth with numpy?

Time: 2017-06-04 19:27:15

Tags: python numpy tensorflow linear-algebra

TensorFlow has a function called tf.space_to_depth. I find its implementation in the TensorFlow source code very hard to follow. Could you help me implement it with numpy?

Below is some code that illustrates what the function does. Before anything else, note that the input to the TensorFlow function must have the shape [batch, height, width, depth].

Consider the following code. First, we define a tensor:

norm = tf.reshape(tf.range(0,72),(1,6,6,2))

These are the values of the first depth channel (norm[0,:,:,0]):

[[ 0,  2,  4,  6,  8, 10],
 [12, 14, 16, 18, 20, 22],
 [24, 26, 28, 30, 32, 34],
 [36, 38, 40, 42, 44, 46],
 [48, 50, 52, 54, 56, 58],
 [60, 62, 64, 66, 68, 70]]

These are the values of the second depth channel (norm[0,:,:,1]):

[[ 1,  3,  5,  7,  9, 11],
 [13, 15, 17, 19, 21, 23],
 [25, 27, 29, 31, 33, 35],
 [37, 39, 41, 43, 45, 47],
 [49, 51, 53, 55, 57, 59],
 [61, 63, 65, 67, 69, 71]]

Next, I apply the tf.space_to_depth function:

trans = tf.space_to_depth(norm,2)

The output shape is (1,3,3,8), and this is the function's output:

trans[0,:,:,0]
[[ 0,  4,  8],
 [24, 28, 32],
 [48, 52, 56]]
trans[0,:,:,1]
[[ 1,  5,  9],
 [25, 29, 33],
 [49, 53, 57]]
trans[0,:,:,2]
[[ 2,  6, 10],
 [26, 30, 34],
 [50, 54, 58]]
trans[0,:,:,3]
[[ 3,  7, 11],
 [27, 31, 35],
 [51, 55, 59]]
trans[0,:,:,4]
[[12, 16, 20],
 [36, 40, 44],
 [60, 64, 68]]
trans[0,:,:,5]
[[13, 17, 21],
 [37, 41, 45],
 [61, 65, 69]]
trans[0,:,:,6]
[[14, 18, 22],
 [38, 42, 46],
 [62, 66, 70]]
trans[0,:,:,7]
[[15, 19, 23],
 [39, 43, 47],
 [63, 67, 71]]

Can someone help me implement a vectorized version of this function in numpy?

Thanks in advance for any reply!

1 Answer:

Answer 0: (score: 2)

You can implement space_to_depth with appropriate calls to the reshape() and swapaxes() functions:

import numpy as np

def space_to_depth(x, block_size):
    x = np.asarray(x)
    batch, height, width, depth = x.shape
    reduced_height = height // block_size
    reduced_width = width // block_size
    # Split height and width into (block index, offset within block).
    y = x.reshape(batch, reduced_height, block_size,
                         reduced_width, block_size, depth)
    # Move both block offsets next to depth, then merge them into the channel axis.
    z = np.swapaxes(y, 2, 3).reshape(batch, reduced_height, reduced_width, -1)
    return z
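
To make the channel ordering behind the reshape and swapaxes calls easier to see, here is a loop-based reference sketch. The name space_to_depth_loops is hypothetical (it is not part of numpy or TensorFlow), and the sketch assumes that height and width are exact multiples of block_size. For each offset (dy, dx) inside a block, it copies the strided sub-grid of pixels into a slice of `depth` consecutive output channels starting at (dy * block_size + dx) * depth.

```python
import numpy as np

def space_to_depth_loops(x, block_size):
    # Hypothetical reference version; assumes height and width are
    # multiples of block_size.
    x = np.asarray(x)
    batch, height, width, depth = x.shape
    out_h = height // block_size
    out_w = width // block_size
    out = np.empty((batch, out_h, out_w, block_size * block_size * depth),
                   dtype=x.dtype)
    for dy in range(block_size):
        for dx in range(block_size):
            # First output channel for this offset within the block.
            start = (dy * block_size + dx) * depth
            # Every block_size-th pixel, starting at row dy and column dx.
            out[:, :, :, start:start + depth] = x[:, dy::block_size,
                                                   dx::block_size, :]
    return out
```

It only loops over the block_size * block_size offsets, not over pixels, but the reshape-based version above is still the more compact choice; this one is mainly useful as a cross-check.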

Here are the examples from the documentation of tf.space_to_depth:

In [328]: x = [[[[1], [2]],
     ...:       [[3], [4]]]]
     ...: 

In [329]: space_to_depth(x, 2)
Out[329]: array([[[[1, 2, 3, 4]]]])

In [330]: x = [[[[1, 2, 3], [4, 5, 6]],
     ...:       [[7, 8, 9], [10, 11, 12]]]]
     ...: 

In [331]: space_to_depth(x, 2)
Out[331]: array([[[[ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12]]]])

In [332]: x = [[[[1],   [2],  [5],  [6]],
     ...:       [[3],   [4],  [7],  [8]],
     ...:       [[9],  [10], [13],  [14]],
     ...:       [[11], [12], [15],  [16]]]]
     ...: 

In [333]: space_to_depth(x, 2)
Out[333]: 
array([[[[ 1,  2,  3,  4],
         [ 5,  6,  7,  8]],

        [[ 9, 10, 11, 12],
         [13, 14, 15, 16]]]])

And here is your example:

In [334]: norm = np.arange(72).reshape(1, 6, 6, 2)

In [335]: trans = space_to_depth(norm, 2)

In [336]: trans[0, :, :, 0]
Out[336]: 
array([[ 0,  4,  8],
       [24, 28, 32],
       [48, 52, 56]])

In [337]: trans[0, :, :, 1]
Out[337]: 
array([[ 1,  5,  9],
       [25, 29, 33],
       [49, 53, 57]])

In [338]: trans[0, :, :, 7]
Out[338]: 
array([[15, 19, 23],
       [39, 43, 47],
       [63, 67, 71]])
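
Only channels 0, 1 and 7 are printed above. As a rough way to check the remaining channels too, the sketch below repeats the function so that it runs on its own, then decodes each output channel index c into a block offset (dy, dx) and an input depth d, and compares that channel with the matching strided slice of the input. The variable names (checks, rest) are illustrative only.

```python
import numpy as np

def space_to_depth(x, block_size):
    # Same reshape/swapaxes approach as in the answer above.
    x = np.asarray(x)
    batch, height, width, depth = x.shape
    reduced_height = height // block_size
    reduced_width = width // block_size
    y = x.reshape(batch, reduced_height, block_size,
                  reduced_width, block_size, depth)
    return np.swapaxes(y, 2, 3).reshape(batch, reduced_height,
                                        reduced_width, -1)

block_size = 2
norm = np.arange(72).reshape(1, 6, 6, 2)
depth = norm.shape[-1]
trans = space_to_depth(norm, block_size)

checks = []
for c in range(trans.shape[-1]):
    # c == (dy * block_size + dx) * depth + d
    dy, rest = divmod(c, block_size * depth)
    dx, d = divmod(rest, depth)
    expected = norm[0, dy::block_size, dx::block_size, d]
    checks.append(np.array_equal(trans[0, :, :, c], expected))

print(all(checks))  # True
```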