定义一个表现为堆叠数组的类型

时间:2015-09-24 15:37:56

标签: julia

我有一个由多个数组组成的类型

type VectorWrapper{T}
   x::Vector{T}
   y::Vector{T}
end

我想要矢量函数(让我们说sumabs2normscale!copy!fill!axpy!map!reduce)对此类型采取行动,就好像它是一个堆叠的向量。例如,我想要以下内容:

sumabs2(a) <-> sumabs2(a.x) + sumabs2(a.y)
copy!(a1, a2) <->  copy!(a1.x, a2.x) ; copy !(a1.y, a2.y)

我看到两个解决方案:

  1. 我可以逐个定义每个函数,但这是重复的。
  2. 我还可以定义sizegetindex函数

    type VectorWrapper{T} <: AbstractVector{T}
       x::Vector{T}
       y::Vector{T}
    end
    Base.getindex(a::VectorWrapper, i::Integer) = i <= length(a.x) ? a.x[i] : a.y[i-length(a.x)]
    Base.size(a::VectorWrapper) = map(+, size(a.x), size(a.y))
    

    但这不符合要求:

    a = VectorWrapper(rand(10_000_000), rand(10_000_000))
    @time sumabs2(a)
    # 0.091090 seconds (7 allocations: 208 bytes)
    @time sumabs2(a.x) + sumabs2(a.y)
    # 0.010433 seconds (7 allocations: 208 bytes)
    

    由于sumabs2(a)中添加的操作,绑定检查以及缺少SIMD矢量化,我猜sumabs2(a.x) + sumabs2(a.y)getindex慢。

  3. 是否有一种解决方案将1的性能与2的简洁性相结合?

2 个答案:

答案 0 :(得分:2)

size目前可能是一个性能陷阱,您是否尝试过编写{{1}}方法而没有它?

另外,你见过https://github.com/tanmaykm/ChainedVectors.jl吗?似乎它已经可能正在做你想要的。

答案 1 :(得分:1)

如何将堆叠数组v定义为主要组件,并将xy作为其数组视图...?例如:

type VecWrap{T} <: AbstractVector{T}
    v::Vector{T}
    x::Vector{T}
    y::Vector{T}

    function VecWrap{T}( x_in::Vector{T}, y_in::Vector{T} )
        ( nx, ny ) = ( length( x_in ), length( y_in ) )
        v = Vector{T}( nx + ny )
        v[ 1      : nx      ] = x_in
        v[ (nx+1) : (nx+ny) ] = y_in

        x = pointer_to_array( pointer( v ),       (nx,) )
        y = pointer_to_array( pointer( v, nx+1 ), (ny,) )
        return new( v, x, y )
    end
end

Base.getindex( a::VecWrap, i::Int ) = a.v[ i ]
Base.setindex!( a::VecWrap, val, i::Int ) = ( a.v[ i ] = val )
Base.size( a::VecWrap ) = size( a.v )
Base.copy( a::VecWrap ) = VecWrap{Float64}( a.x, a.y )
Base.copy!( b::VecWrap, a::VecWrap ) = copy!( b.v, a.v )

function test()
    n = 10_000_000
    a = VecWrap{Float64}( rand( n ), rand( n ) )
    for loop = 1:3
        println( "loop = $loop" )
        @time sumabs2( a )
        @time sumabs2( a.x ) + sumabs2( a.y )
        @time sumabs2( a.v )
    end
end

在我的电脑上,结果是

loop = 1
  0.012153 seconds
  0.009812 seconds
  0.009667 seconds
loop = 2
  0.011365 seconds
  0.009657 seconds
  0.009641 seconds
loop = 3
  0.011350 seconds
  0.009658 seconds
  0.009665 seconds

和填充!()等似乎没问题(虽然没有完全确认)。将x和y定义为SubArray似乎也具有几乎相同的效率(0.009-0.011秒)。