将转置矩阵与jblas相乘

时间:2014-03-09 22:12:58

标签: matrix-multiplication

我正在用Jblas实现反向传播和梯度下降。

第1层是向量A:DoubleMatrix(M,1)

第2层是向量B:DoubleMatrix(N,1)

它们之间是权重W:DoubleMatrix(M,N)

在正向传球期间,我将B = W \乘以A

W.mmulti(A, B)

在反向传播期间,我正在计算A =(B ^ T \乘以W)^ T

A = B.transpose().mmul(W).transpose()

我编写了代码,以便所有内容都可以就地计算并且非常快。但是Jblas transpose()方法创建了一个全新的对象并复制了所有数据,在每次迭代时调用两次都非常昂贵。有没有办法在乘法过程中使用DoubleMatrix转置,而不进行所有这些复制?看起来内部实现起来很容易 - 使用相同的数据对象,但是将调用切换到行和列。

2 个答案:

答案 0 :(得分:1)

我问自己一个很好的问题。 我没有答案,但这个简单的事实可以挽救其中一个转换:

(B ^ T \乘以W)^ T = W ^ T * B

所以你要写

A = W.transpose().mmul(B)

答案 1 :(得分:0)

nd4j的作者在这里。我们添加了nd数组和常量时间转置以及矩阵乘法,它们与各种我们称之为“后端”的东西一起工作,其中一个是jblas。

API也与jblas非常相似。

INDArray arr = Nd4j.create(2,2); INDArray结果= arr.mmul(arr);

最好的部分是你得到cuda和(很快就会被添加!)opencl免费。

对于大多数事情我们已经迁移到netlib-java但是继续提供jblas(因为它非常稳定)但是在实践中相互运行后端我们已经找到了在jblas上进行复制操作的整洁操作(因为内部不能) t支持抵消)