如何避免forkjoin continuation

时间:2016-01-31 22:13:48

标签: java multithreading continuations thread-local forkjoinpool

  

这个问题与如何使用ThreadLocal无关。我的问题是   关于ForkJoinPool继续使用ForkJoinTask.compute()的副作用,它破坏了ThreadLocal契约。

ForkJoinTask.compute()中,我拉出一个任意的静态ThreadLocal。

该值是一些任意有状态对象,但在compute()调用结束后不具有状态。换句话说,我准备threadlocal对象/状态,使用它,然后处置。

原则上你会把那个状态放在ForkJoinTasK中,但只是假设这个线程本地值在第三方lib中我无法改变。因此是静态threadlocal,因为它是所有任务实例将共享的资源。

我预计,测试并证明简单的ThreadLocal只会初始化一次,当然。这意味着由于ForkJoinTask.join()调用下的线程延续,我的compute()方法甚至可以在退出之前再次调用。 这会暴露先前计算调用中使用的对象的状态,许多堆栈帧更高。

您如何解决不良曝光问题?

我目前看到的唯一方法是确保每个compute()调用都有新的线程,但这会破坏F / J池的延续,并可能危险地爆炸线程数。

在JRE核心中有没有什么可以做的事情来备份自第一个ForkJoinTask以来改变的TL并恢复整个threadlocal映射,好像每个task.compute都是第一个在线程上运行的?

感谢。

package jdk8tests;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinWorkerThread;
import java.util.concurrent.RecursiveTask;
import java.util.concurrent.atomic.AtomicInteger;

public class TestForkJoin3 {

    static AtomicInteger nextId = new AtomicInteger();
    static long T0 = System.currentTimeMillis();
    static int NTHREADS = 5;
    static final ThreadLocal<StringBuilder> myTL = ThreadLocal.withInitial( () -> new StringBuilder());

    static void log(Object msg) {
        System.out.format("%09.3f %-10s %s%n", new Double(0.001*(System.currentTimeMillis()-T0)), Thread.currentThread().getName(), " : "+msg);
    }

    public static void main(String[] args) throws Exception {
        ForkJoinPool p = new ForkJoinPool(
                NTHREADS,
                pool -> {
                    int id = nextId.incrementAndGet(); //count new threads
                    log("new FJ thread "+ id);
                    ForkJoinWorkerThread t = new ForkJoinWorkerThread(pool) {/**/};
                    t.setName("My FJThread "+id);
                    return t;
                },
                Thread.getDefaultUncaughtExceptionHandler(),
                false
        );

        LowercasingTask t = new LowercasingTask("ROOT", 3);
        p.invoke(t);

        int nt = nextId.get();
        log("number of threads was "+nt);
        if(nt > NTHREADS)
            log(">>>>>>> more threads than prescribed <<<<<<<<");
    }


    //=====================

    static class LowercasingTask extends RecursiveTask<String> {
        String name;
        int level;
        public LowercasingTask(String name, int level) {
            this.name = name;
            this.level = level;
        }

        @Override
        protected String compute() {
            StringBuilder sbtl = myTL.get();
            String initialValue = sbtl.toString();
            if(!initialValue.equals(""))
                log("!!!!!! BROKEN ON START!!!!!!! value = "+ initialValue);

            sbtl.append(":START");

            if(level>0) {
                log(name+": compute level "+level);
                try {Thread.sleep(10);} catch (InterruptedException e) {e.printStackTrace();}

                List<LowercasingTask> tasks = new ArrayList<>();
                for(int i=1; i<=9; i++) {
                    LowercasingTask lt = new LowercasingTask(name+"."+i, level-1);
                    tasks.add(lt);
                    lt.fork();
                }

                for(int i=0; i<tasks.size(); i++) { //this can lead to compensation threads due to l1.join() method running lifo task lN
                //for(int i=tasks.size()-1; i>=0; i--) { //this usually has the lN.join() method running task lN, without compensation threads.
                    tasks.get(i).join();
                }

                log(name+": returning from joins");

            }

            sbtl.append(":END");

            String val = sbtl.toString();
            if(!val.equals(":START:END"))
                log("!!!!!! BROKEN AT END !!!!!!! value = "+val);

            sbtl.setLength(0);
            return "done";
        }
    }

}

1 个答案:

答案 0 :(得分:2)

我不相信。不是一般的,特别是不适用于ForkJoinTask,其中任务应该是孤立对象上的纯函数。

有时可以在任务的开始和之前将任务的顺序更改为fork和join。这样子任务将在返回之前初始化并处理线程局部。如果这是不可能的,也许您可​​以将线程局部视为堆栈并推送,清除和恢复每个连接的值。