以下代码使用'generic_filter'方法,从 ndimage ( scipy python模块),计算3x3元素子矩阵的均值;将每个矩阵元素视为中心元素(并将其排除)并避免边界效应。
import numpy as np
from scipy import ndimage
a = np.reshape(np.arange(25),(5,5))
print a
matrix = np.array(a).astype(np.float)
mask = np.ones((3, 3))
mask[1, 1] = 0
result = ndimage.generic_filter(matrix, np.nanmean, footprint = mask, mode='constant', cval=np.NaN)
print result
结果打印如下:
[[ 0 1 2 3 4]
[ 5 6 7 8 9]
[10 11 12 13 14]
[15 16 17 18 19]
[20 21 22 23 24]]
[[ 4. 4. 5. 6. 6.66666667]
[ 5.6 6. 7. 8. 8.4 ]
[ 10.6 11. 12. 13. 13.4 ]
[ 15.6 16. 17. 18. 18.4 ]
[ 17.33333333 18. 19. 20. 20. ]]
可以证实它按预期工作[第一个意思是(5 + 6 + 1)/ 3 = 4,第二个意思是(0 + 6 + 5 + 7 + 2)/ 5 = 4;等等]。
我的问题是如何访问 ndimage.generic_filter 使用的单个值(子矩阵)来计算每个 np.nanmean ?
答案 0 :(得分:1)
ndimage.generic_filter
为每个子数组调用一次函数。您将该函数指定为第二个参数。因此,如果您将np.nanmean
更改为自定义函数func
(请参见下文),那么您可以从func
内访问子数组。
要累积可在调用ndimage.generic_filter
后访问的子数组列表,您可以将列表作为额外参数传递给func
,并将子数组附加到列表中func
。然后,您可以访问该列表(及其内容):
import numpy as np
from scipy import ndimage
def func(x, subarrays):
print(x)
subarrays.append(x)
return np.nanmean(x)
a = np.reshape(np.arange(25),(5,5))
matrix = np.array(a).astype(np.float)
mask = np.ones((3, 3))
mask[1, 1] = 0
subarrays = []
result = ndimage.generic_filter(matrix, func, footprint = mask,
mode='constant', cval=np.NaN,
extra_arguments=(subarrays,))
print(result)
print(len(subarrays))
打印每个子阵列:
[ nan nan nan nan 1. nan 5. 6.]
[ nan nan nan 0. 2. 5. 6. 7.]
[ nan nan nan 1. 3. 6. 7. 8.]
[ nan nan nan 2. 4. 7. 8. 9.]
[ nan nan nan 3. nan 8. 9. nan]
[ nan 0. 1. nan 6. nan 10. 11.]
[ 0. 1. 2. 5. 7. 10. 11. 12.]
[ 1. 2. 3. 6. 8. 11. 12. 13.]
[ 2. 3. 4. 7. 9. 12. 13. 14.]
[ 3. 4. nan 8. nan 13. 14. nan]
[ nan 5. 6. nan 11. nan 15. 16.]
[ 5. 6. 7. 10. 12. 15. 16. 17.]
[ 6. 7. 8. 11. 13. 16. 17. 18.]
[ 7. 8. 9. 12. 14. 17. 18. 19.]
[ 8. 9. nan 13. nan 18. 19. nan]
[ nan 10. 11. nan 16. nan 20. 21.]
[ 10. 11. 12. 15. 17. 20. 21. 22.]
[ 11. 12. 13. 16. 18. 21. 22. 23.]
[ 12. 13. 14. 17. 19. 22. 23. 24.]
[ 13. 14. nan 18. nan 23. 24. nan]
[ nan 15. 16. nan 21. nan nan nan]
[ 15. 16. 17. 20. 22. nan nan nan]
[ 16. 17. 18. 21. 23. nan nan nan]
[ 17. 18. 19. 22. 24. nan nan nan]
[ 18. 19. nan 23. nan nan nan nan]
并打印最终的result
:
[[ 4. 4. 5. 6. 6.66666667]
[ 5.6 6. 7. 8. 8.4 ]
[ 10.6 11. 12. 13. 13.4 ]
[ 15.6 16. 17. 18. 18.4 ]
[ 17.33333333 18. 19. 20. 20. ]]
和subarrays
的长度:
25