Fork加入优化

时间:2014-03-05 07:32:03

标签: java multithreading concurrency java-8 fork-join

我想要什么

我想研究fork / join算法的优化。通过优化,我的意思是计算最佳线程数,或者如果您需要 - 计算SEQUENTIAL_THRESHOLD(参见下面的代码)。

// PSEUDOCODE
Result solve(Problem problem) { 
    if (problem.size < SEQUENTIAL_THRESHOLD)
        return solveSequentially(problem);
    else {
        Result left, right;
        INVOKE-IN-PARALLEL { 
            left = solve(extractLeftHalf(problem));
            right = solve(extractRightHalf(problem));
        }
        return combine(left, right);
    }
}

我怎么想象

例如,我想计算大数组的乘积。然后我只评估所有组件并获得最佳线程数量:

SEQUENTIAL_THRESHOLD = PC * IS / MC(仅举例)

PC - 处理器核心数量; IS - 常量,表示具有一个处理器内核的最佳阵列大小和对数据的最简单操作(例如读取); MC - 乘以运营成本;

假设MC = 15; PC = 4且IS = 10000; SEQUENTIAL_THRESHOLD = 2667。如果子任务数组大于2667,我会把它分叉。

广泛的问题

  1. 是否可以这样制作SEQUENTIAL_THRESHOLD公式?
  2. 是否可以为更复杂的计算完成相同的工作:不仅是对数组/集合的操作和排序?
  3. 狭隘的问题:

    对于数组/集合/排序的SEQUENTIAL_THRESHOLD的计算,是否已经存在一些调查?他们是如何做到的?

    2014年3月7日更新:

    1. 如果没有办法为阈值计算编写单个公式,我可以编写一个将在PC上执行预定义测试的util,并获得最佳阈值吗?这也不可能吗?
    2. Java 8 Streams API可以做什么?它可以帮助我吗? Java 8 Streams API是否消除了Fork / Join的需求?

3 个答案:

答案 0 :(得分:5)

除非您与执行环境保持密切关系,否则绝对无法计算合适的阈值。我在sourceforge.net上维护了一个fork / join项目,这是我在大多数内置函数中使用的代码:

private int calcThreshold(int nbr_elements, int passed_threshold) {

  // total threads in session
  // total elements in array
  int threads = getNbrThreads();
  int count   = nbr_elements + 1;

  // When only one thread, it doesn't pay to decompose the work,
  //   force the threshold over array length
  if  (threads == 1) return count;    

  /*
   * Whatever it takes
   * 
   */
  int threshold = passed_threshold;

  // When caller suggests a value
  if  (threshold > 0) {

      // just go with the caller's suggestion or do something with the suggestion

  } else {
      // do something usful such as using about 8 times as many tasks as threads or
      //   the default of 32k
      int temp = count / (threads << 3);
      threshold = (temp < 32768) ? 32768 : temp;

  } // endif    

  // whatever
  return threshold;

}

3月9日编辑:

你怎么可能拥有一个通用工具,它不仅可以知道处理器速度,可用内存,处理器数量等(物理环境),还可以知道软件的用途?答案是你不能。这就是为什么你需要为每个环境开发一个例程。上面的方法是我用于基本数组(向量)。我使用另一个方法进行大多数矩阵处理:

// When very small, just spread every row
if  (count < 6) return 1;

// When small, spread a little 
if  (count < 30) return ((count / (threads << 2) == 0)? threads : (count / (threads << 2)));  

// this works well for now
return ((count / (threads << 3) == 0)? threads : (count / (threads << 3))); 

就Java8流而言:他们使用F / J框架,你无法指定一个阈值。

答案 1 :(得分:3)

由于以下几个原因,你无法将其简化为一个简单的公式:

  • 每台PC的参数差异很大,不仅取决于核心,还取决于其他因素,如RAM计时或后台任务。

  • Java本身在执行过程中即时优化循环。因此,几秒钟之后瞬间完美的设置可能是次优的。或者更糟糕的是:调整可能会阻止完美的优化。

我能看到的唯一方法是以某种形式的AI或遗传算法动态调整值。然而,这包括程序经常检查非最佳设置以确定当前设置是否仍然是最佳的。因此,如果获得的速度实际上高于尝试其他设置所损失的速度,则值得怀疑。最后可能只是初始学习阶段的解决方案,而进一步执行则将这些训练值用作固定数字。

由于这不仅花费时间而且大大增加了代码复杂性,我不认为这是大多数程序的选项。通常,首先不使用Fork-Join更有利,因为还有许多其他并行化选项可能更适合这个问题。

“遗传”算法的一个想法是测量每次运行的循环效率,然后有一个不断更新的背景散列映射loop-parameters -> execution time,并且为大多数运行选择最快的设置。

答案 2 :(得分:1)

这是一个非常有趣的问题需要调查。我编写了这个简单的代码来测试顺序阈值的最佳值。我无法得出任何具体的结论,很可能是因为我在只有2个处理器的旧笔记本电脑上运行它。许多运行后唯一一致的观察结果是所用时间迅速下降,直到连续阈值为100.尝试运行此代码并让我知道您发现了什么。同样在底部,我附加了一个python脚本,用于绘制结果,以便我们可以直观地看到趋势。

import java.io.FileWriter;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;

public class Testing {

static int SEQ_THRESHOLD;

public static void main(String[] args) throws Exception {
    int size = 100000;
    int[] v1 = new int[size];
    int[] v2 = new int[size];
    int[] v3 = new int[size];
    for (int i = 0; i < size; i++) {
        v1[i] = i;  // Arbitrary initialization
        v2[i] = 2 * i; // Arbitrary initialization
    }
    FileWriter fileWriter = new FileWriter("OutTime.dat");

    // Increment SEQ_THRESHOLD and save time taken by the code to run in a file
    for (SEQ_THRESHOLD = 10; SEQ_THRESHOLD < size; SEQ_THRESHOLD += 50) {
        double avgTime = 0.0;
        int samples = 5;
        for (int i = 0; i < samples; i++) {
            long startTime = System.nanoTime();
            ForkJoinPool fjp = new ForkJoinPool();
            fjp.invoke(new VectorAddition(0, size, v1, v2, v3));
            long endTime = System.nanoTime();
            double secsTaken = (endTime - startTime) / 1.0e9;
            avgTime += secsTaken;
        }
        fileWriter.write(SEQ_THRESHOLD + " " + (avgTime / samples) + "\n");
    }

    fileWriter.close();
}
}

class VectorAddition extends RecursiveAction {

int[] v1, v2, v3;
int start, end;

VectorAddition(int start, int end, int[] v1, int[] v2, int[] v3) {
    this.start = start;
    this.end = end;
    this.v1 = v1;
    this.v2 = v2;
    this.v3 = v3;
}

int SEQ_THRESHOLD = Testing.SEQ_THRESHOLD;

@Override
protected void compute() {
    if (end - start < SEQ_THRESHOLD) {
        // Simple vector addition
        for (int i = start; i < end; i++) {
            v3[i] = v1[i] + v2[i];
        }
    } else {
        int mid = (start + end) / 2;
        invokeAll(new VectorAddition(start, mid, v1, v2, v3),
                new VectorAddition(mid, end, v1, v2, v3));
    }
}
}

这是用于绘制结果的Python脚本:

from pylab import *

threshold = loadtxt("./OutTime.dat", delimiter=" ", usecols=(0,))
timeTaken = loadtxt("./OutTime.dat", delimiter=" ", usecols=(1,))

plot(threshold, timeTaken)
show()