如何使我的dot产品方法更快或更高效?

时间:2016-11-05 03:21:34

标签: java optimization matrix dot-product

我有一个小的Java方法,用于在输入向量和矩阵之间执行点积。这是代码:

    public void calcOutput() {
    outputs = new float[output];
    float sum = 0F;

    for(int j = 0; j < output; j++) {
        for(int i = 0; i < input; i++) {
            sum += inputs[i] * weights[j][i];
        }

        outputs[j] = sum;
    }
}

基本上应该做的是将我的输入向量'输入'并使用我称为“权重”的矩阵执行点积。然后将输出放在输出矢量'outputs'中。

如何更快或更高效?如果有帮助,我的体重矩阵也不需要是矩阵。我只需要一种方法来轻松访问相应的索引。

由于

3 个答案:

答案 0 :(得分:3)

不,没有更好的事情。它是您可以实现的最简单的方法,该算法遵循良好的内存缓存方法,即外部循环遵循数组的外部索引,内部循环遍历一个子数组内的元素。

也许它可以帮助为内部数组使用临时变量,但我想JIT会处理这个问题。

此外,还有一个错误,sum变量应该在外部循环的范围内,而不是方法范围。它需要在外循环的每次迭代时重置:

for(int j = 0; j < output; j++) {
    // NOTE the line:
    float sum = 0;
    // and the reference to inner array:
    byte[] row = weights[j];
    for(int i = 0; i < input; i++) {
        sum += inputs[i] * row[i];
    }

    outputs[j] = sum;
}

答案 1 :(得分:1)

这就是我要做的。通过反转外部和内部循环,可以减少inputs数组中的查找次数。此外,您不需要sum变量 - 您可以直接在outputs数组中添加。

    float[] outputs = new float[output];

    for(int i = 0; i < input; i++) {
        float inputsI = inputs[i];
        for(int j = 0; j < output; j++) {
            outputs[j] += inputsI * weights[j][i];
        }

    }

我希望这只会快一点。在几乎所有现实世界的应用程序中,不值得担心像这样微小的微小优化。

答案 2 :(得分:1)

有几种方法比编写香草点积更好。天真的实现将通过C2进行矢量化,但是顺序归约阶段是如此缓慢,以至于矢量化乘法的好处被抵消了。现在在Java(JDK10)中,最好的办法是使用部分和解展开数据依赖关系。 C2将发出标量代码,但是它将使用一些流水线操作,每个周期最多可以获得4个触发器。

float s0 = 0f;
float s1 = 0f;
float s2 = 0f;
float s3 = 0f;
float s4 = 0f;
float s5 = 0f;
float s6 = 0f;
float s7 = 0f;
for (int i = 0; i < size; i += 8) {
  s0 = Math.fma(left[i + 0],  right[i + 0], s0);
  s1 = Math.fma(left[i + 1],  right[i + 1], s1);
  s2 = Math.fma(left[i + 2],  right[i + 2], s2);
  s3 = Math.fma(left[i + 3],  right[i + 3], s3);
  s4 = Math.fma(left[i + 4],  right[i + 4], s4);
  s5 = Math.fma(left[i + 5],  right[i + 5], s5);
  s6 = Math.fma(left[i + 6],  right[i + 6], s6);
  s7 = Math.fma(left[i + 7],  right[i + 7], s7);
}
return s0 + s1 + s2 + s3 + s4 + s5 + s6 + s7;

要尽快进行,您需要使用累加器进行显式矢量化。可以使用Project Panama Vector API编写这样的代码。

var sum1 = YMM_FLOAT.zero();
var sum2 = YMM_FLOAT.zero();
var sum3 = YMM_FLOAT.zero();
var sum4 = YMM_FLOAT.zero();
int width = YMM_FLOAT.length();
for (int i = 0; i < size; i += width * 4) {
  sum1 = YMM_FLOAT.fromArray(left, i).fma(YMM_FLOAT.fromArray(right, i), sum1);
  sum2 = YMM_FLOAT.fromArray(left, i + width).fma(YMM_FLOAT.fromArray(right, i + width), sum2);
  sum3 = YMM_FLOAT.fromArray(left, i + width * 2).fma(YMM_FLOAT.fromArray(right, i + width * 2), sum3);
  sum4 = YMM_FLOAT.fromArray(left, i + width * 3).fma(YMM_FLOAT.fromArray(right, i + width * 3), sum4);
}
return sum1.addAll() + sum2.addAll() + sum3.addAll() + sum4.addAll();

有关基准和深入分析,请参见此blog post