我有一个小的Java方法,用于在输入向量和矩阵之间执行点积。这是代码:
public void calcOutput() {
outputs = new float[output];
float sum = 0F;
for(int j = 0; j < output; j++) {
for(int i = 0; i < input; i++) {
sum += inputs[i] * weights[j][i];
}
outputs[j] = sum;
}
}
基本上应该做的是将我的输入向量'输入'并使用我称为“权重”的矩阵执行点积。然后将输出放在输出矢量'outputs'中。
如何更快或更高效?如果有帮助,我的体重矩阵也不需要是矩阵。我只需要一种方法来轻松访问相应的索引。
由于
答案 0 :(得分:3)
不,没有更好的事情。它是您可以实现的最简单的方法,该算法遵循良好的内存缓存方法,即外部循环遵循数组的外部索引,内部循环遍历一个子数组内的元素。
也许它可以帮助为内部数组使用临时变量,但我想JIT会处理这个问题。
此外,还有一个错误,sum
变量应该在外部循环的范围内,而不是方法范围。它需要在外循环的每次迭代时重置:
for(int j = 0; j < output; j++) {
// NOTE the line:
float sum = 0;
// and the reference to inner array:
byte[] row = weights[j];
for(int i = 0; i < input; i++) {
sum += inputs[i] * row[i];
}
outputs[j] = sum;
}
答案 1 :(得分:1)
这就是我要做的。通过反转外部和内部循环,可以减少inputs
数组中的查找次数。此外,您不需要sum
变量 - 您可以直接在outputs
数组中添加。
float[] outputs = new float[output];
for(int i = 0; i < input; i++) {
float inputsI = inputs[i];
for(int j = 0; j < output; j++) {
outputs[j] += inputsI * weights[j][i];
}
}
我希望这只会快一点。在几乎所有现实世界的应用程序中,不值得担心像这样微小的微小优化。
答案 2 :(得分:1)
有几种方法比编写香草点积更好。天真的实现将通过C2进行矢量化,但是顺序归约阶段是如此缓慢,以至于矢量化乘法的好处被抵消了。现在在Java(JDK10)中,最好的办法是使用部分和解展开数据依赖关系。 C2将发出标量代码,但是它将使用一些流水线操作,每个周期最多可以获得4个触发器。
float s0 = 0f;
float s1 = 0f;
float s2 = 0f;
float s3 = 0f;
float s4 = 0f;
float s5 = 0f;
float s6 = 0f;
float s7 = 0f;
for (int i = 0; i < size; i += 8) {
s0 = Math.fma(left[i + 0], right[i + 0], s0);
s1 = Math.fma(left[i + 1], right[i + 1], s1);
s2 = Math.fma(left[i + 2], right[i + 2], s2);
s3 = Math.fma(left[i + 3], right[i + 3], s3);
s4 = Math.fma(left[i + 4], right[i + 4], s4);
s5 = Math.fma(left[i + 5], right[i + 5], s5);
s6 = Math.fma(left[i + 6], right[i + 6], s6);
s7 = Math.fma(left[i + 7], right[i + 7], s7);
}
return s0 + s1 + s2 + s3 + s4 + s5 + s6 + s7;
要尽快进行,您需要使用累加器进行显式矢量化。可以使用Project Panama Vector API编写这样的代码。
var sum1 = YMM_FLOAT.zero();
var sum2 = YMM_FLOAT.zero();
var sum3 = YMM_FLOAT.zero();
var sum4 = YMM_FLOAT.zero();
int width = YMM_FLOAT.length();
for (int i = 0; i < size; i += width * 4) {
sum1 = YMM_FLOAT.fromArray(left, i).fma(YMM_FLOAT.fromArray(right, i), sum1);
sum2 = YMM_FLOAT.fromArray(left, i + width).fma(YMM_FLOAT.fromArray(right, i + width), sum2);
sum3 = YMM_FLOAT.fromArray(left, i + width * 2).fma(YMM_FLOAT.fromArray(right, i + width * 2), sum3);
sum4 = YMM_FLOAT.fromArray(left, i + width * 3).fma(YMM_FLOAT.fromArray(right, i + width * 3), sum4);
}
return sum1.addAll() + sum2.addAll() + sum3.addAll() + sum4.addAll();
有关基准和深入分析,请参见此blog post。