Using F# units of measure with System.Numerics.Vector<T>

Date: 2016-01-02 17:31:30

Tags: .net f# vectorization simd units-of-measurement

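I am struggling to use F# units of measure together with the System.Numerics.Vector<'T> type. Let's look at a toy problem: suppose we have an array xs of type float<m>[] and, for some reason, we want to square all of its components, yielding an array of type float<m^2>[]. This works nicely with scalar code: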

xs |> Array.map (fun x -> x * x) // float<m^2>[]

Now suppose we want to vectorize this operation by performing the multiplication with SIMD in chunks of size System.Numerics.Vector<float>.Count, for example:

open System.Numerics
[<Measure>] type m
let simdWidth = Vector<float>.Count
// fill with dummy data
let xs = Array.init (simdWidth * 10) (fun i -> float i * 1.0<m>)
// array to store the results
let rs: float<m^2> array = Array.zeroCreate (xs |> Array.length)
// number of SIMD operations required
let chunks = (xs |> Array.length) / simdWidth
// for simplicity, assume xs.Length % simdWidth = 0
for i = 0 to chunks - 1 do
    let v = Vector<_>(xs, i * simdWidth) // Vector<float<m>>, containing xs.[i*simdWidth .. (i+1)*simdWidth-1]
    let u = v * v                        // Vector<float<m>>; expected: Vector<float<m^2>>
    u.CopyTo(rs, i * simdWidth)          // units mismatch

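I believe I understand why this happens: how would the F# compiler know what System.Numerics.Vector<'T>.op_Multiply does and which arithmetic rules apply? It could be literally any operation. So how should it infer the correct units?

The question is: what is the best way to make this work? How can we tell the compiler which rules apply?

Attempt 1: Remove all units-of-measure information from xs and add it back later: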

// remove all UoM from all arrays
let xsWoM = Array.map (fun x -> x / 1.0<m>) xs
// ...
// perform computation using xsWoM etc.
// ...
// add back units again
let xs = Array.map (fun x -> x * 1.0<m>) xsWoM

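Problems: this performs unnecessary computations and/or copy operations, which defeats the purpose of vectorizing the code for performance. It also largely defeats the purpose of using UoM in the first place.

Attempt 2: Use inline IL to change the return type of Vector<'T>.op_Multiply: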

// reinterpret x to be of type 'b
let inline retype (x: 'a) : 'b = (# "" x: 'b #)
let inline (.*.) (u: Vector<float<'m>>) (v: Vector<float<'m>>): Vector<float<'m^2>> = u * v |> retype
// ...
let u = v .*. v // asserts type Vector<float<m^2>>

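Problems: this requires no additional operations, but it uses a deprecated feature (inline IL) and is not fully generic (only with respect to the unit of measure).

Does anyone have a better solution?*

*Note that the example above really is a toy problem to demonstrate the general issue. The real program solves a more complicated initial value problem involving many kinds of physical quantities.

2 Answers:

Answer 0: (score: 1)

The compiler is perfectly able to work out how to apply unit rules for multiplication; the problem here is that you have a wrapped type. In your first example, when you write xs |> Array.map (fun x -> x * x), you describe the multiplication in terms of the elements of the array rather than on the array directly.

When you have a Vector<float<m>>, the units are attached to the float rather than to the Vector, so when you multiply Vectors the compiler does not treat that type as having any units.

Given the methods the class exposes, I don't think there is an easy workaround using Vector<'T> directly, but there are options for wrapping the type.

Something like this could give you a unit-friendly vector: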

type VectorWithUnits<'a, [<Measure>]'b> = 
    |VectorWithUnits of Vector<'a>

    static member inline (*) (a : VectorWithUnits<'a0, 'b0>, b : VectorWithUnits<'a0, 'b1>) 
        : VectorWithUnits<'a0, 'b0*'b1> =
        match a, b with
        |VectorWithUnits a, VectorWithUnits b -> VectorWithUnits <| a * b

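In this case the units are attached to the vector, and multiplying vectors produces the expected units.

The catch is that we can now have separate and distinct unit-of-measure annotations on the Vector<'T> and on the float itself.

You can convert an array of a specific type with units of measure into an array of Vectors using: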

let toFloatVectors (array : float<'m>[]) : VectorWithUnits<float,'m>[]  =
    let arrs = array |> Array.chunkBySize (Vector<float>.Count)
    arrs |> Array.map (Array.map (float) >> Vector >> VectorWithUnits)

and back:

let fromFloatVectors (vectors : VectorWithUnits<float,'m>[]) : float<'m>[] =
    let width = Vector<float>.Count
    // each vector contributes `width` elements to the result
    let arr = Array.zeroCreate<float> (Array.length vectors * width)
    vectors |> Array.iteri (fun i uVec ->
        match uVec with
        |VectorWithUnits vec -> vec.CopyTo(arr, i * width))
    arr |> Array.map (LanguagePrimitives.FloatWithMeasure<'m>)
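One caveat with the chunking step in toFloatVectors: Array.chunkBySize leaves a shorter final chunk when the array length is not a multiple of Vector<float>.Count, and as far as I know the Vector<float>(float[]) constructor needs at least Count elements. A minimal sketch of one way around that, assuming zero padding is acceptable for your computation (padToMultiple is a made-up helper name, not part of any library):

```fsharp
[<Measure>] type m

// Hypothetical helper: pad with zeros so the length is a multiple of `width`.
// Returns the input array unchanged when no padding is needed.
let padToMultiple (width: int) (xs: float<'u>[]) : float<'u>[] =
    let remainder = xs.Length % width
    if remainder = 0 then xs
    else Array.append xs (Array.zeroCreate (width - remainder))

// Five elements padded up to the next multiple of four.
let padded = padToMultiple 4 [| 1.0<m>; 2.0<m>; 3.0<m>; 4.0<m>; 5.0<m> |]
```

The padded elements would have to be dropped again after converting back, so the original length needs to be remembered somewhere.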

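A hacky alternative:

If you give up the generic type 'T, you can make a float Vector behave properly through some fairly horrible boxing and runtime casting. This abuses the fact that units of measure are a compile-time construct that no longer exists at runtime.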

type FloatVectorWithUnits<[<Measure>]'b> = 
    |FloatVectorWithUnits of Vector<float<'b>>

    static member ( * ) (a : FloatVectorWithUnits<'b0>, b : FloatVectorWithUnits<'b1>) =
        match a, b with
        |FloatVectorWithUnits a, FloatVectorWithUnits b ->
            let c, d = box a :?> Vector<float<'b0*'b1>>, box b :?> Vector<float<'b0*'b1>>
            c * d |> FloatVectorWithUnits
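The erasure this hack relies on can be checked directly. The following is only a sketch to illustrate the idea on plain arrays; stripUnits and withUnits are hypothetical helper names introduced here, and the same box/unbox round trip is what the cast above does for Vector<float<_>>:

```fsharp
[<Measure>] type m

// float<m> and float are the same type once compiled.
let sameRuntimeType = (typeof<float<m>> = typeof<float>)

// Hypothetical helpers: change an array's unit annotation without copying it.
// Nothing checks that the new unit is physically meaningful.
let stripUnits (xs: float<'u>[]) : float[] = unbox (box xs)
let withUnits<[<Measure>] 'u> (xs: float[]) : float<'u>[] = unbox (box xs)

let xs = [| 1.0<m>; 2.0<m>; 3.0<m> |]
let raw = stripUnits xs

// Both names refer to the same underlying array object.
let sharesStorage = obj.ReferenceEquals(xs, raw)

// Compute on raw floats, then assert the resulting unit by hand.
let squared : float<m^2>[] = raw |> Array.map (fun x -> x * x) |> withUnits<m^2>
```

Because the compiler cannot verify the unit you assert in withUnits, it is safest to keep such casts inside a small, well-tested wrapper.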

Answer 1: (score: 1)

I came up with a solution that meets most of my requirements (it seems). It is inspired by TheInnerLight's ideas (wrapping Vector<'T>), but also adds a wrapper (called ScalarField) for the underlying array data type. This way we can track the units, while underneath we only handle raw data and can use the non-unit-aware System.Numerics.Vector APIs.

A streamlined, quick-and-dirty implementation looks like this:

// units-aware wrapper for System.Numerics.Vector<'T>
type PackedScalars<[<Measure>] 'm> = struct
    val public Data: Vector<float>
    new (d: Vector<float>) = {Data = d}
    static member inline (*) (u: PackedScalars<'m1>, v: PackedScalars<'m2>) = u.Data * v.Data |> PackedScalars<'m1*'m2>
end

// unit-aware type, wrapping a raw array for easy stream processing
type ScalarField<[<Measure>] 'm> = struct
    val public Data: float[]
    member self.Item with inline get i                = LanguagePrimitives.FloatWithMeasure<'m> self.Data.[i]
                     and  inline set i (v: float<'m>) = self.Data.[i] <- (float v)
    member self.Packed 
           with inline get i                        = Vector<float>(self.Data, i) |> PackedScalars<'m>
           and  inline set i (v: PackedScalars<'m>) = v.Data.CopyTo(self.Data, i)
    new (d: float[]) = {Data = d}
    new (count: int) = {Data = Array.zeroCreate count}
end

We can now use both data structures to solve the sample problem in a relatively elegant and efficient way:

let xs = Array.init (simdWidth * 10) float |> ScalarField<m>
let mutable rs = Array.zeroCreate (xs.Data |> Array.length) |> ScalarField<m^2>
let chunks = (xs.Data |> Array.length) / simdWidth
for i = 0 to chunks - 1 do
    let j = i * simdWidth
    let v = xs.Packed(j) // PackedScalars<m>
    let u = v * v        // PackedScalars<m^2>
    rs.Packed(j) <- u

On top of that, it might be useful to re-implement regular array operations for the unit-aware ScalarField wrapper, e.g.

[<CompilationRepresentation(CompilationRepresentationFlags.ModuleSuffix)>]
module ScalarField =
    let map f (sf: ScalarField<_>) =
        let mutable res = Array.zeroCreate sf.Data.Length |> ScalarField
        for i = 0 to sf.Data.Length - 1 do
            res.[i] <- f sf.[i]
        res

Drawback: this is not generic with respect to the underlying numeric type (float), because there is no generic replacement for FloatWithMeasure. To make it generic, we have to implement a third wrapper, Scalar, that also wraps the underlying primitive:

type Scalar<'a, [<Measure>] 'm> = struct
    val public Data: 'a
    new (d: 'a) = {Data = d}
end

type PackedScalars<'a, [<Measure>] 'm 
            when 'a: (new: unit -> 'a) 
            and  'a: struct 
            and  'a :> System.ValueType> = struct
    val public Data: Vector<'a>
    new (d: Vector<'a>) = {Data = d}
    static member inline (*) (u: PackedScalars<'a, 'm1>, v: PackedScalars<'a, 'm2>) = u.Data * v.Data |> PackedScalars<'a, 'm1*'m2>
end

type ScalarField<'a, [<Measure>] 'm
            when 'a: (new: unit -> 'a) 
            and  'a: struct 
            and  'a :> System.ValueType> = struct
    val public Data: 'a[]
    member self.Item with inline get i                    = Scalar<'a, 'm>(self.Data.[i])
                     and  inline set i (v: Scalar<'a,'m>) = self.Data.[i] <- v.Data
    member self.Packed 
           with inline get i                          = Vector<'a>(self.Data, i) |> PackedScalars<_,'m>
           and  inline set i (v: PackedScalars<_,'m>) = v.Data.CopyTo(self.Data, i)
    new (d:'a[]) = {Data = d}
    new (count: int) = {Data = Array.zeroCreate count}
end

...meaning that we basically do not track the units by using "refinement" types like float<'m>, but solely through wrapper types with a secondary type/units argument.

However, I still hope someone comes up with a better idea. :)