从多维Numpy数组中减去均值

时间:2019-11-28 17:35:51

标签: python arrays numpy numpy-broadcasting

我目前正在学习有关Numpy广播的知识,我正在读这本书(Python for Data Analysis by Wes McKinney,作者提到了以下示例以“贬低”二维数组:

import numpy as np

arr = np.random.randn(4, 3)
print(arr.mean(0))
demeaned = arr - arr.mean(0)
print(demeaned)
print(demeand.mean(0))

有效地导致数组demeaned的平均值为0。

我有个想法将其应用于类似图像的三维数组:

import numpy as np

arr = np.random.randint(0, 256, (400,400,3))
demeaned = arr - arr.mean(2)

这当然失败了,因为根据广播规则,尾随尺寸必须匹配,而事实并非如此:

print(arr.shape)  # (400, 400, 3)
print(arr.mean(2).shape)  # (400, 400)

现在,通过从数组第三维中的每个索引中减去均值,我已经使其大部分工作了:

demeaned = np.ones(arr.shape)

for i in range(3):
    demeaned[...,i] = arr[...,i] - means

print(demeaned.mean(0))

在这一点上,返回值非常接近零,我认为这是一个精度错误。我真的对这个想法是正确的,还是我错过了另一个警告?

此外,这并不是实现我想要实现的最干净,最“ numpy”的方式。有没有我可以利用的功能或原理来改进代码?

2 个答案:

答案 0 :(得分:1)

从numpy版本1.7.0,np.mean和其他几个函数开始,在其axis参数中接受一个元组。这意味着您可以一次在图像平面上执行操作:

m = arr.mean(axis=(0, 1))

此均值将具有(3,)形状,图像的每个平面都有一个元素。

如果要分别减去每个像素的均值,则必须记住广播在右侧边缘对齐了形状元组。这意味着您需要插入一个额外的尺寸:

n = arr.mean(axis=2)
n = n.reshape(*n.shape, 1)

n = arr.mean(axis=2)[..., None]

答案 1 :(得分:1)

尝试np.apply_along_axis()

np.apply_along_axis(lambda x: x - np.mean(x), 2, arr)

输出:您将获得具有相同形状的数组,其中每个单元在所需的维度上均值降低(第二个参数,此处为2)。