我想编写一些有效的方法来处理某些数据结构中的matricies。我测试了两个相同的外部产品功能,一个在普通矩阵上运行,另一个在结构域上运行。第二个函数运行慢约25倍:
mutable struct MyMatrix{T<:Real}
mtx::Array{T}
MyMatrix{T}(len) where T<:Real = new(Array{T}(len, len))
end
function outerprod!(M::MyMatrix{T}, x1::Vector{T}, x2::Vector{T}) where T<:Real
# mtx = M.mtx - using local reference doesn't help
len1 = length(x1)
len2 = length(x2)
size(M.mtx,1) == len1 && size(M.mtx,2) == len2 || error("length mismatch!")
for c=1:len2, r=1:len1
M.mtx[r,c] = x1[r]*x2[c]
end
M
end
function outerprod!(mtx::Array{T}, x1::Vector{T}, x2::Vector{T}) where T<:Real
len1 = length(x1)
len2 = length(x2)
size(mtx,1) == len1 && size(mtx,2) == len2 || error("length mismatch!")
for c=1:len2, r=1:len1
mtx[r,c] = x1[r]*x2[c]
end
mtx
end
N = 100;
v1 = collect(Float64, 1:N)
v2 = collect(Float64, N:-1:1)
m = Array{Float64}(100,100)
M = MyMatrix{Float64}(100)
@time outerprod!(M,v1,v2);
>> 0.001334 seconds (10.00 k allocations: 156.406 KiB)
@time outerprod!(m,v1,v2);
>> 0.000055 seconds (4 allocations: 160 bytes)
最后,当我编写第三个版本时,引用快速函数,它在结构上运行得很快:
function outerprod_!(M::MyMatrix{T}, x1::Vector{T}, x2::Vector{T}) where T<:Real
outerprod!(M.mtx, x1, x2)
M
end
@time outerprod_!(M,v1,v2);
>> 0.000058 seconds (4 allocations: 160 bytes)
第一个功能出了什么问题?
P.S。在这个问题上挣扎了一段时间,在julia中寻找不同的优化,最后发现了这个。
答案 0 :(得分:1)
主要问题是Array{<:Real}
不是具体类型:
julia> Array{<:Real}
Array{#s29,N} where N where #s29<:Real
此类型包含任何可能的N
,而您对矩阵非常感兴趣,因此它应该是Array{T, 2}
,或者更容易输入和理解Matrix{T}
。另外,请注意您的MyMatrix
类型可以是不可变的:在不可变结构中,您不能设置字段,但如果字段本身是可变的,则可以设置其内部字段。此外,for
- 循环可以通过@inbounds
:
struct MyMatrix{T<:Real}
mtx::Matrix{T}
MyMatrix{T}(len) where T<:Real = new(Array{T}(len, len))
end
function outerprod!(M::MyMatrix{T}, x1::Vector{T}, x2::Vector{T}) where T<:Real
# mtx = M.mtx - using local reference doesn't help
len1 = length(x1)
len2 = length(x2)
size(M.mtx,1) == len1 && size(M.mtx,2) == len2 || error("length mismatch!")
@inbounds for c=1:len2, r=1:len1
M.mtx[r,c] = x1[r]*x2[c]
end
M
end
function outerprod!(mtx::Array{T}, x1::Vector{T}, x2::Vector{T}) where T<:Real
len1 = length(x1)
len2 = length(x2)
size(mtx,1) == len1 && size(mtx,2) == len2 || error("length mismatch!")
@inbounds for c=1:len2, r=1:len1
mtx[r,c] = x1[r]*x2[c]
end
mtx
end
N = 100;
v1 = collect(Float64, 1:N)
v2 = collect(Float64, N:-1:1)
m = Matrix{Float64}(100,100)
M = MyMatrix{Float64}(100)
测试速度:
julia> using BenchmarkTools
julia> @btime outerprod!(m,v1,v2);
2.746 μs (0 allocations: 0 bytes)
julia> @btime outerprod!(M,v1,v2);
2.746 μs (0 allocations: 0 bytes)