Julia:如何编写修改结构字段的快速函数?

时间:2017-11-26 22:44:14

标签: matrix julia

我想编写一些有效的方法来处理某些数据结构中的matricies。我测试了两个相同的外部产品功能,一个在普通矩阵上运行,另一个在结构域上运行。第二个函数运行慢约25倍:

mutable struct MyMatrix{T<:Real}
    mtx::Array{T}
    MyMatrix{T}(len) where T<:Real = new(Array{T}(len, len))
end

function outerprod!(M::MyMatrix{T}, x1::Vector{T}, x2::Vector{T}) where T<:Real
    # mtx = M.mtx - using local reference doesn't help
    len1 = length(x1)
    len2 = length(x2)
    size(M.mtx,1) == len1 && size(M.mtx,2) == len2 || error("length mismatch!")
    for c=1:len2, r=1:len1
        M.mtx[r,c] = x1[r]*x2[c]
    end
    M
end

function outerprod!(mtx::Array{T}, x1::Vector{T}, x2::Vector{T}) where T<:Real
    len1 = length(x1)
    len2 = length(x2)
    size(mtx,1) == len1 && size(mtx,2) == len2 || error("length mismatch!")
    for c=1:len2, r=1:len1
        mtx[r,c] = x1[r]*x2[c]
    end
    mtx
end

N = 100;
v1 = collect(Float64, 1:N)
v2 = collect(Float64, N:-1:1)
m = Array{Float64}(100,100)
M = MyMatrix{Float64}(100)

@time outerprod!(M,v1,v2);
>>  0.001334 seconds (10.00 k allocations: 156.406 KiB)

@time outerprod!(m,v1,v2);
>>  0.000055 seconds (4 allocations: 160 bytes)

最后,当我编写第三个版本时,引用快速函数,它在结构上运行得很快:

function outerprod_!(M::MyMatrix{T}, x1::Vector{T}, x2::Vector{T}) where T<:Real
    outerprod!(M.mtx, x1, x2)
    M
end

@time outerprod_!(M,v1,v2);
>>  0.000058 seconds (4 allocations: 160 bytes)

第一个功能出了什么问题?

P.S。在这个问题上挣扎了一段时间,在julia中寻找不同的优化,最后发现了这个。

1 个答案:

答案 0 :(得分:1)

主要问题是Array{<:Real}不是具体类型:

julia> Array{<:Real}
Array{#s29,N} where N where #s29<:Real

此类型包含任何可能的N,而您对矩阵非常感兴趣,因此它应该是Array{T, 2},或者更容易输入和理解Matrix{T}。另外,请注意您的MyMatrix类型可以是不可变的:在不可变结构中,您不能设置字段,但如果字段本身是可变的,则可以设置其内部字段。此外,for - 循环可以通过@inbounds

获得加速
struct MyMatrix{T<:Real}
    mtx::Matrix{T}
    MyMatrix{T}(len) where T<:Real = new(Array{T}(len, len))
end

function outerprod!(M::MyMatrix{T}, x1::Vector{T}, x2::Vector{T}) where T<:Real
    # mtx = M.mtx - using local reference doesn't help
    len1 = length(x1)
    len2 = length(x2)
    size(M.mtx,1) == len1 && size(M.mtx,2) == len2 || error("length mismatch!")
    @inbounds for c=1:len2, r=1:len1
        M.mtx[r,c] = x1[r]*x2[c]
    end
    M
end

function outerprod!(mtx::Array{T}, x1::Vector{T}, x2::Vector{T}) where T<:Real
    len1 = length(x1)
    len2 = length(x2)
    size(mtx,1) == len1 && size(mtx,2) == len2 || error("length mismatch!")
    @inbounds for c=1:len2, r=1:len1
        mtx[r,c] = x1[r]*x2[c]
    end
    mtx
end

N = 100;
v1 = collect(Float64, 1:N)
v2 = collect(Float64, N:-1:1)
m = Matrix{Float64}(100,100)
M = MyMatrix{Float64}(100)

测试速度:

julia> using BenchmarkTools

julia> @btime outerprod!(m,v1,v2);
  2.746 μs (0 allocations: 0 bytes)

julia> @btime outerprod!(M,v1,v2);
  2.746 μs (0 allocations: 0 bytes)