网格使用线程

时间:2015-06-01 23:38:08

标签: java multithreading

我正在研究分类问题,我已经实现了网格搜索算法,以便找到最佳准确度。我的问题是程序的执行时间大约是2小时,我试图通过使用线程来改善这段时间。显然我做错了,因为即使在实现线程之后执行时间也是一样的。贝娄是算法。

我必须指定这是我第一次使用线程,我已经阅读了一些关于Executors的好东西,但我无法弄清楚如何实现它们。

public static void gridSearch(Dataset ds)
        {
            double bestAcc = 0;

            for (int i = -5; i < 15; i++) {
                double param1 = Math.pow(2, i);
                for (int j = -15; j < 3; j++) {
                    double param2 = Math.pow(2, j);

                    int size = 10;
                    CrossValidation[] works = new CrossValidation[size];
                    Thread[] threads = new Thread[size];
                    for (int k=1;k<=size;k++) {
                        CrossValidation po = new CrossValidation(param1, param2, ds);;
                        works[k-1] = po;
                        Thread t = new Thread(po);
                        threads[k-1] = t;
                        t.start();
                    }
                    for (int k = 0; k < size; k++) {
                        try { threads[k].join(); } catch (InterruptedException ex) {}
                        double accuracy = works[k].getAccuracy();
                        accuracy /= 106;
                        if (accuracy > bestAccuracy)
                            bestAcc = accuracy;
                    }
                }
            }
            System.out.println("Best accuracy: " + bestAcc);
        }

CrossValidation类实现了Runnable,并且有一个返回准确性的方法getAccuracy

请帮我弄清楚我做错了什么,以便缩短执行时间。

2 个答案:

答案 0 :(得分:3)

您的问题似乎是您为每个参数设置10个线程而不是为每个参数设置启动线程。仔细看看你在这里做了什么。您正在生成param1param2,然后启动10个与这些参数一起使用的线程 - 冗余。在此之后,您将等待这些线程完成,然后再重新开始。

但不用担心,我为你准备了一些东西......

我想告诉你如何让线程池做你真正想要实现的目标。一旦你运行它就会更容易理解并注意:

您可以下载整个示例here

首先,您需要WorkerThreadCVResult之类的内容来返回结果。这是您要执行CrossValidation算法的地方:

public static class CVResult {      
    public double param1;
    public double param2;
    public double accuracy;
}

public static class WorkerThread implements Runnable {

    private double param1;
    private double param2;
    private double accuracy;

    public WorkerThread(double param1, double param2){   
        this.param1 = param1;
        this.param2 = param2;
    }

    @Override
    public void run() {
        System.out.println(Thread.currentThread().getName() +
                " [parameter1] " + param1 + " [parameter2]: " + param2);
        processCommand();
    }

    private void processCommand() {
        try {

            Thread.sleep(500);
            ;
            /*
             * ### PERFORM YOUR CROSSVALIDATION ALGORITHM HERE ###
             */

            this.accuracy = this.param1 + this.param2;

            // Give back result:

            CVResult result = new CVResult();

            result.accuracy = this.accuracy;
            result.param1 = this.param1;
            result.param2 = this.param2;

            Main.addResult(result);

        } catch (InterruptedException e) {
            e.printStackTrace();
        }
    }
}

您还需要确保您有权访问ExecutorServiceList<Future>ExecutorService将处理您的线程,我们将初始化线程数作为CPU可用的核心数。 这将确保没有比CPU上可用的核心更多的线程运行 - 但是 - 没有任务丢失,因为每个线程都排队并在另一个线程完成后启动。你很快就会看到。 List<Future>将允许我们在继续主线程之前等待所有线程完成。 List<CVResult>当然是保存线程添加的结果(请注意它是同步的,因为多个线程将访问它)。

private static ExecutorService executor = null;
private static List<Future> futures = new ArrayList<>();
private static List<CVResult> resultList = Collections.synchronizedList(new ArrayList<CVResult>());

这就是gridSearch()的样子。你不必在这里初始化executor ..你可以在任何你想要的地方做到这一点:

public static void gridSearch(/*Dataset ds*/)
{
    double bestAcc = 0;

    int cores = Runtime.getRuntime().availableProcessors();
    executor = Executors.newFixedThreadPool(cores);

    for (int i = -5; i < 15; i++) {

        double param1 = Math.pow(2, i);

        for (int j = -15; j < 3; j++) { 

            double param2 = Math.pow(2, j);  

            Runnable worker = new WorkerThread(param1, param2);
            futures.add(executor.submit(worker));
        }
    }

    System.out.println("Waiting for all threads to terminate ..");

    // Joining all threads in order to wait for all to finish
    // before returning from gridSearch()
    for (Future future: futures) {
        try {
            future.get(100, TimeUnit.SECONDS); 
        } catch (Throwable cause) {
            // process cause
        }
    }

    System.out.println("Printing results ..");

    for(CVResult result : resultList) {
        System.out.println("Acc: " + result.accuracy + 
                " for param1: " + result.param1 + 
                " | param2: " + result.param2);
    }
}

最后但并非最不重要的是,这是一个将结果添加到列表的同步方法:

public static void addResult(CVResult accuracy) {
    synchronized( resultList ) {
        resultList.add(accuracy);
    }       
}

如果你在主要电话中打电话,例如像这样:

public static void main(String[] args) {
    gridSearch(/* params */);       
    System.out.println("All done.");
}

您将获得如下输出:

...
pool-1-thread-5 [parameter1] 0.0625 [parameter2]: 3.0517578125E-5
param1 0.03125
param2 1.0
pool-1-thread-4 [parameter1] 0.0625 [parameter2]: 0.25
param1 0.0625
param2 0.03125
...
Printing results ..
...
Acc: 16384.5 for param1: 16384.0 | param2: 0.5
Acc: 16386.0 for param1: 16384.0 | param2: 2.0
...
All done.

答案 1 :(得分:0)

可能因为线程创建/拆除开销增加了运行线程所需的时间,请使用Executors解决此问题。 This will help you get started。如已经评论过,您的处理器也可能没有可用的处理线程或物理内核来同时执行您的线程。

更突出的是,在每次-15到3次迭代之间,你必须等待。要解决此问题,请在处理完所有内容后将等待和处理移至for循环的末尾。这样,在开始下一批之前,最后10个线程不需要完全。此外,我建议在处理结果之前使用CountDownLatch等待完全完成。