我试图以非矢量化和半矢量化的方式计算std。
非矢量化版本的代码运行良好,半矢量化版本也可以使用,但它们生成的结果并不相同。
这是版本1:
import math
#unvectorized version --really slow!
def calc_std_classic(a):
batch = a.shape[0]
channel = a.shape[1]
width = a.shape[2]
height = a.shape[3]
mean = calc_mean_classic2(data_train)
sum = np.zeros((channel))
for i in range(batch):
for j in range(channel):
for w in range(width):
for h in range(height):
sum[j] += (abs(a[i,j,w,h] - mean[j])**2)
var = (sum/(width*height*batch))
return [(math.sqrt(x)) for x in var ]
半矢量化:
def calc_std_classic2(a):
batch = a.shape[0]
channel = a.shape[1]
width = a.shape[2]
height = a.shape[3]
mean = calc_mean_classic2(data_train)
sum = np.zeros((channel))
for i in range(batch):
for j in range(channel):
sum[j] += np.sum(abs(a[i,j,:,:] - mean[j])**2)
var = (sum/(width*height*batch))
return [(math.sqrt(x)) for x in var ]
这是计算平均值的方法,如果需要的话:
def calc_mean_classic2(a):
#sum all elements in each channel and divide by the number of elements
batch = a.shape[0]
channel = a.shape[1]
width = a.shape[2]
height = a.shape[3]
sum = np.zeros((channel))
for i in range(batch):
for j in range(channel):
sum[j] += np.sum(a[i,j,:,:])
return (sum/(width*height*batch))
使用pythons numpy.std()生成的输出和两个方法如下:
std = np.std(data_train, axis=(0,2,3))
std2 = calc_std_classic(data_train)
std3 = calc_std_classic2(data_train)
生成:
std = [ 62.99321928 62.08870764 66.70489964]
std2 = [62.99321927774396, 62.08870764038716, 66.7048996406483]
std3 = [62.99321927813685, 62.088707640014405, 66.70489964063101]
如您所见,这三个都会生成相同的结果,最多8位数。但第三种方法有不同的剩余数字。
我在这里做错了什么?
答案 0 :(得分:1)
浮点算术错误传播有很多好的资源。但是一个直接的问题是numpy.ndarray
显示浮动到python float
的不同精度。因此,为了比较您的结果,您应该转换为相同的数据结构(例如list
s):
>>> print(np.std(arr, ....))
[ 0.28921072 0.2898092 0.28961785]
>>> print(np.std(arr, ....).tolist())
[0.28921072085015914, 0.28980920065339233, 0.28961784922639483]
在你的明确案例中:
calc_std_classic
和calc_std_classic2
之间的差异是因为其中一个使用天真求和a1+a2+....+an
而另一个使用np.sum
。 np.sum
可能是天真的总结,但据我所知它使用pairwise summation。如果您想要更高的准确度,可以实现Kahan summation或使用python-builtin statistics._sum
。
np.std
与您的变体之间的差异难以解释,因为我不知道numpy
使用了什么算法。整整article about "Algorithms for calculating variance" on wikipedia。请注意,任何天真的实现都可能遭遇欠/溢问题,尤其是因为item - mean
减法。
一般建议:
如果你想要它快,那么使用numpy
,如果你想要它具有最高精度,那么使用statistics
。 NumPy主要关注性能方面,因此他们可能不会实现最准确的算法。如果不对算法进行研究,就要避免任何天真的实现,因为它们既不准确也不快。