如何访问ndimage.generic_filter使用的各个值(scipy模块)

时间:2017-10-01 19:36:25

标签: python-2.7 scipy

以下代码使用'generic_filter'方法,从 ndimage scipy python模块),计算3x3元素子矩阵的均值;将每个矩阵元素视为中心元素(并将其排除)并避免边界效应。

import numpy as np
from scipy import ndimage

a = np.reshape(np.arange(25),(5,5))
print a

matrix = np.array(a).astype(np.float)

mask = np.ones((3, 3))
mask[1, 1] = 0

result = ndimage.generic_filter(matrix, np.nanmean, footprint = mask, mode='constant', cval=np.NaN)

print result

结果打印如下:

[[ 0  1  2  3  4]
 [ 5  6  7  8  9]
 [10 11 12 13 14]
 [15 16 17 18 19]
 [20 21 22 23 24]]
[[  4.           4.           5.           6.           6.66666667]
 [  5.6          6.           7.           8.           8.4       ]
 [ 10.6         11.          12.          13.          13.4       ]
 [ 15.6         16.          17.          18.          18.4       ]
 [ 17.33333333  18.          19.          20.          20.        ]]

可以证实它按预期工作[第一个意思是(5 + 6 + 1)/ 3 = 4,第二个意思是(0 + 6 + 5 + 7 + 2)/ 5 = 4;等等]。

我的问题是如何访问 ndimage.generic_filter 使用的单个值(子矩阵)来计算每个 np.nanmean

1 个答案:

答案 0 :(得分:1)

ndimage.generic_filter为每个子数组调用一次函数。您将该函数指定为第二个参数。因此,如果您将np.nanmean更改为自定义函数func(请参见下文),那么您可以从func内访问子数组。

要累积可在调用ndimage.generic_filter后访问的子数组列表,您可以将列表作为额外参数传递给func,并将子数组附加到列表中func。然后,您可以访问该列表(及其内容):

import numpy as np
from scipy import ndimage

def func(x, subarrays):
    print(x)
    subarrays.append(x)
    return np.nanmean(x)

a = np.reshape(np.arange(25),(5,5))
matrix = np.array(a).astype(np.float)
mask = np.ones((3, 3))
mask[1, 1] = 0
subarrays = []
result = ndimage.generic_filter(matrix, func, footprint = mask, 
                                mode='constant', cval=np.NaN,
                                extra_arguments=(subarrays,))

print(result)
print(len(subarrays))

打印每个子阵列:

[ nan  nan  nan  nan   1.  nan   5.   6.]
[ nan  nan  nan   0.   2.   5.   6.   7.]
[ nan  nan  nan   1.   3.   6.   7.   8.]
[ nan  nan  nan   2.   4.   7.   8.   9.]
[ nan  nan  nan   3.  nan   8.   9.  nan]
[ nan   0.   1.  nan   6.  nan  10.  11.]
[  0.   1.   2.   5.   7.  10.  11.  12.]
[  1.   2.   3.   6.   8.  11.  12.  13.]
[  2.   3.   4.   7.   9.  12.  13.  14.]
[  3.   4.  nan   8.  nan  13.  14.  nan]
[ nan   5.   6.  nan  11.  nan  15.  16.]
[  5.   6.   7.  10.  12.  15.  16.  17.]
[  6.   7.   8.  11.  13.  16.  17.  18.]
[  7.   8.   9.  12.  14.  17.  18.  19.]
[  8.   9.  nan  13.  nan  18.  19.  nan]
[ nan  10.  11.  nan  16.  nan  20.  21.]
[ 10.  11.  12.  15.  17.  20.  21.  22.]
[ 11.  12.  13.  16.  18.  21.  22.  23.]
[ 12.  13.  14.  17.  19.  22.  23.  24.]
[ 13.  14.  nan  18.  nan  23.  24.  nan]
[ nan  15.  16.  nan  21.  nan  nan  nan]
[ 15.  16.  17.  20.  22.  nan  nan  nan]
[ 16.  17.  18.  21.  23.  nan  nan  nan]
[ 17.  18.  19.  22.  24.  nan  nan  nan]
[ 18.  19.  nan  23.  nan  nan  nan  nan]

并打印最终的result

[[  4.           4.           5.           6.           6.66666667]
 [  5.6          6.           7.           8.           8.4       ]
 [ 10.6         11.          12.          13.          13.4       ]
 [ 15.6         16.          17.          18.          18.4       ]
 [ 17.33333333  18.          19.          20.          20.        ]]

subarrays的长度:

25