使用Arrayfire.jl进行矢量交叉产品的可能解决方法?

时间:2017-01-12 15:47:18

标签: julia arrayfire

我正在尝试使用ArrayFire.jl进行矢量数学运算,但是矢量交叉产品的功能未在Arrayfire中实现。 是否有一种解决方法可以使用Julia的Arrayfire.jl包装器以高效的方式计算它?由于设备和主机之间的所有数据传输,以一种天真的方式定义函数非常慢,而且我不了解包装函数,足以弄清楚如何解决这个问题。

cross(a::ArrayFire.AFArray, b::ArrayFire.AFArray) = ArrayFire.AFArray([a[2]*b[3]-a[3]*b[2]; a[3]*b[1]-a[1]*b[3]; a[1]*b[2]-a[2]*b[1]]);

2 个答案:

答案 0 :(得分:1)

为了回答自己,可以使用circshift()函数完成交叉积,以在GPU中创建移位向量,然后可以进行逐元素乘法和减法。它不是最优雅的方式,但它有效。

function cross(a::ArrayFire.AFArray{Float32,1}, b::ArrayFire.AFArray{Float32,1})
    ashift = circshift(a, [-1]);
    ashift2 = circshift(a, [-2]);
    bshift = circshift(b, [-2]);
    bshift2 = circshift(b, [-1]);
    c::ArrayFire.AFArray{Float32,1} = ashift.*bshift - ashift2.*bshift2;
end

答案 1 :(得分:0)

我认为以下内容应该有效:

function cross!(c::AFArray, a::AFArray, b::AFArray)
    c[1] = a[2]*b[3]-a[3]*b[2]
    c[2] = a[3]*b[1]-a[1]*b[3]
    c[3] = a[1]*b[2]-a[2]*b[1]
end

c = AFArray(zeros(3))
a = AFArray([1.0, 2, 3])
b = AFArray([3.0, 4, 5])

cross!(c, a, b)