我正在尝试编写一个快速坐标下降算法来解决普通最小二乘回归问题。以下Julia代码有效,但我不明白它为什么要分配这么多内存
function OLS_cd{T<:Float64}(A::Array{T,2}, b::Array{T,1}, tolerance::T=1e-12)
N,P = size(A)
x = zeros(P)
r = copy(b)
d = ones(P)
while sum(d.*d) > tolerance
@inbounds for j = 1:P
d[j] = sum(A[:,j].*r)
x[j] += d[j]
r -= d[j]*A[:,j]
end
end
return(x)
end
我用
生成的数据n = 100
p = 75
σ = 0.1
β_nz = float([i*(-1)^i for i in 1:10])
β = append!(β_nz,zeros(p-length(β_nz)))
X = randn(n,p); X .-= mean(X,1); X ./= sqrt(sum(abs2(X),1))
y = X*β + σ*randn(n); y .-= mean(y);
使用@benchmark OLS_cd(X, y)
我
BenchmarkTools.Trial:
memory estimate: 65.94 mb
allocs estimate: 151359
--------------
minimum time: 19.316 ms (16.49% GC)
median time: 20.545 ms (16.60% GC)
mean time: 22.164 ms (16.24% GC)
maximum time: 42.114 ms (10.82% GC)
--------------
samples: 226
evals/sample: 1
time tolerance: 5.00%
memory tolerance: 1.00%
随着p
越来越大,OLS问题变得越来越困难,我注意到随着我的体积变大,需要运行更长时间,Julia分配的内存越多。
为什么每个通过while
循环分配更多内存?在我看来,似乎我的所有操作都已到位,并且类型已明确指定。
在分析时没有任何东西弹出,但如果它有用,我也可以发布该输出。
更新 如下所述,使用矢量化操作引起的临时数组是罪魁祸首。以下消除了无关的分配并且运行得非常快:
function OLS_cd_unrolled{T<:Float64}(A::Array{T,2}, b::Array{T,1}, tolerance::T=1e-12)
N,P = size(A)
x = zeros(P)
r = copy(b)
d = ones(P)
while norm(d,Inf) > tolerance
@inbounds for j = 1:P
d[j] = 0.0; @inbounds for i = 1:N d[j] += A[i,j]*r[i] end
@inbounds for i = 1:N r[i] -= d[j]*A[i,j] end
x[j] += d[j]
end
end
return(x)
end
答案 0 :(得分:5)
A[:,j]
创建副本,而不是视图。您想使用@view A[:,j]
或view(A,:,j)
。
您可以使用r -= d[j]*A[:,j]
对r .= -.(r,d[j]*A[:.j])
进行开发,以摆脱更多临时工。正如@LutfullahTomak所说sum(A[:,j].*r)
应该发展为dot(view(A,:,j),r)
,以摆脱那里的所有临时工。要使用中缀运算符,您可以使用\cdot
,如view(A,:,j)⋅r
。
您应该阅读副本与视图以及矢量化如何导致临时数组。其中的一点是,当矢量化操作发生时,它们必须创建一个新的矢量作为输出。相反,您想要写入现有的向量。 r = ...
表示数组更改引用,因此r = ex
表示创建数组的某个表达式将创建一个新数组,然后将r
指向该数组。 r .= ex
将使用表达式中的值替换数组r
的值。前者分配一个临时,后者则不分配。所有临时工作都来自于这种想法的反复应用。
答案 1 :(得分:2)
实际上,sum(d.*d)
,sum(A[:,j].*r)
等不在位并制作临时数组。首先,sum(d.*d) == dot(d,d)
我认为sum(A[:,j].*r)
会生成2个临时数组。我会为后者做dot(view(A,:,j),r)
。当前稳定版本的julia(0.5)没有r -= d[j]*A[:,j]
的短版本,所以你需要对它进行devectorize做一个循环。