使用Numpy stride_tricks获取非重叠的数组块

时间:2011-11-09 19:20:33

标签: python numpy

我正在尝试使用numpy.lib.stride_tricks.as_strided来迭代数组的非重叠块,但是我无法找到参数的文档,所以我只能得到重叠的块

例如,我有一个4x5阵列,我希望得到4个2x2块。我很好,右边和底边的额外细胞被排除在外。

到目前为止,我的代码是:

import sys
import numpy as np

a = np.array([
[1,2,3,4,5],
[6,7,8,9,10],
[11,12,13,14,15],
[16,17,18,19,20],
])

sz = a.itemsize
h,w = a.shape
bh,bw = 2,2

shape = (h/bh, w/bw, bh, bw)
strides = (w*sz, sz, w*sz, sz)
blocks = np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)

print blocks[0][0]
assert blocks[0][0].tolist() == [[1, 2], [6,7]]
print blocks[0][1]
assert blocks[0][1].tolist() == [[3,4], [8,9]]
print blocks[1][0]
assert blocks[1][0].tolist() == [[11, 12], [16, 17]]

生成的块数组的形状似乎是正确的,但最后两个断言失败,可能是因为我的形状或步幅参数不正确。我应该设置哪些值来获得非重叠的块?

3 个答案:

答案 0 :(得分:13)

import numpy as np
n=4
m=5
a = np.arange(1,n*m+1).reshape(n,m)
print(a)
# [[ 1  2  3  4  5]
#  [ 6  7  8  9 10]
#  [11 12 13 14 15]
#  [16 17 18 19 20]]
sz = a.itemsize
h,w = a.shape
bh,bw = 2,2
shape = (h/bh, w/bw, bh, bw)
print(shape)
# (2, 2, 2, 2)

strides = sz*np.array([w*bh,bw,w,1])
print(strides)
# [40  8 20  4]

blocks=np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)
print(blocks)
# [[[[ 1  2]
#    [ 6  7]]
#   [[ 3  4]
#    [ 8  9]]]
#  [[[11 12]
#    [16 17]]
#   [[13 14]
#    [18 19]]]]

1 a开始(即blocks[0,0,0,0]),到达2(即blocks[0,0,0,1])是一个项目。由于(在我的机器上)a.itemsize是4个字节,所以步幅是1 * 4 = 4.这给了我们strides = (10,2,5,1)*a.itemsize = (40,8,20,4)中的最后一个值。

再次从1开始,到达6(即blocks[0,0,1,0]),距离是5(即w)项,所以步幅为5 * 4 = 20.这说明strides中的倒数第二个值。

再次从1开始,到达3(即blocks[0,1,0,0]),距离是2(即bw)项,所以步幅为2 * 4 = 8.这表示strides中的第二个值。

最后,从1开始,到达11(即blocks[1,0,0,0]),距离是10(即w*bh)项,所以步幅为10 * 4 = 40.所以strides = (40,8,20,4)

答案 1 :(得分:5)

以@ unutbu的答案为例,我编写了一个函数来实现任何ND数组的拼贴技巧。请参阅下面的链接到源。

>>> a = numpy.arange(1,21).reshape(4,5)

>>> print a
[[ 1  2  3  4  5]
 [ 6  7  8  9 10]
 [11 12 13 14 15]
 [16 17 18 19 20]]

>>> blocks = blockwise_view(a, blockshape=(2,2), require_aligned_blocks=False)

>>> print blocks
[[[[ 1 2]
   [ 6 7]]

  [[ 3 4]
   [ 8 9]]]


 [[[11 12]
   [16 17]]

  [[13 14]
   [18 19]]]]

[blockwise_view.py] [test_blockwise_view.py]

答案 2 :(得分:1)

scikit-image有一个名为view_as_blocks()的函数,可以几乎你需要的东西。唯一的问题是它有一个额外的assert禁止你的用例,因为你的块不会均匀地分成你的数组形状。但在您的情况下,assert不是必需的,因此您可以复制function source code并安全地删除自相矛盾的断言。