nd4j中广播的矩阵乘法

时间:2018-12-05 06:03:38

标签: nd4j

在python中,假设

a = np.array(range(0,12)).reshape(2,2,3)
b = np.array(range(0,6)).reshape(3,2)
c = np.matmul(a,b) // a @ b

我们有

a: array([[[ 0,  1,  2],
        [ 3,  4,  5]],

       [[ 6,  7,  8],
        [ 9, 10, 11]]])

b: array([[0, 1],
       [2, 3],
       [4, 5]])

c: array([[[10, 13],
        [28, 40]],

       [[46, 67],
        [64, 94]]])

有人可以帮助我实现Java nd4j 中的等效操作而无需for循环吗?我尝试过broadcast.mul,但事实证明broadcast.mul是逐元素乘法。我没有找到mmul的任何广播操作。

1 个答案:

答案 0 :(得分:1)

我自己弄清楚了。如果有人需要,答案如下所示。 使用Nd4j.tensorMmul,可以轻松实现矩阵广播。例如

val a = Nd4j.create(0d to 11d by 1d toArray, Array[Int](2, 2, 3))
val b = Nd4j.create(0d to 5d by 1d toArray, Array[Int](3, 2))
Nd4j.tensorMmul(a, b, Array(Array(2), Array(0))) // matrix broadcast

这是scala的代码。对于Java,只需更改代码即可创建数组。