使用ConcurrentHashMap进行并行化

时间:2013-03-01 23:02:50

标签: java concurrency parallel-processing hashmap

我有一个程序,我试图理解线程并行性。这个程序涉及硬币翻转并计算头尾数(以及硬币翻转总数)。

请参阅以下代码:

import java.util.Random;
import java.util.concurrent.ConcurrentHashMap;

public class CoinFlip{


    // main
    public static void main (String[] args) {
        if (args.length != 2){
            System.out.println("CoinFlip #threads #iterations");
            return;
        }

        // check if arguments are integers
        int numberOfThreads = 0;
        long iterations = 0;

        try{
            numberOfThreads = Integer.parseInt(args[0]);
            iterations = Long.parseLong(args[1]);
        }catch(NumberFormatException e){
            System.out.println("error: I asked for numbers mate.");
            System.out.println("error: " + e);
            System.exit(1);
        }

        // ------------------------------
        // set time field
        // ------------------------------


        // create a hashmap
        ConcurrentHashMap <String, Long> universalMap = new ConcurrentHashMap <String, Long> ();

        // store count for heads, tails and iterations
        universalMap.put("HEADS", new Long(0));
        universalMap.put("TAILS", new Long(0));
        universalMap.put("ITERATIONS", new Long(0));

        long startTime = System.currentTimeMillis();

        Thread[] doFlip = new Thread[numberOfThreads];

        for (int i = 0; i < numberOfThreads; i ++){
            doFlip[i] = new Thread( new DoFlip(iterations/numberOfThreads, universalMap));
            doFlip[i].start();
        }

        for (int i = 0; i < numberOfThreads; i++){
            try{
                doFlip[i].join();
            }catch(InterruptedException e){
                System.out.println(e);
            }
        }

        // log time taken to accomplish task
        long elapsedTime = System.currentTimeMillis() - startTime;
        System.out.println("Runtime:" + elapsedTime);

        // print the output to check if the values are legal
        // iterations = heads + tails = args[1]
        System.out.println(
            universalMap.get("HEADS") + " " +
            universalMap.get("TAILS") + " " +
            universalMap.get("ITERATIONS") + "."
        );

        return;
    }



    private static class DoFlip implements Runnable{

        // local counters for heads/tails/count
        long heads = 0, tails = 0, iterations = 0;
        Random randomHT = new Random();

        // constructor values -----------------------
        long times = 0; // number of iterations
        ConcurrentHashMap <String, Long> map; // pointer to hash map

        DoFlip(long times, ConcurrentHashMap <String, Long> map){
            this.times = times;
            this.map = map;
        }

        public void run(){
            while(this.times > 0){
                int r = randomHT.nextInt(2); // 0 and 1

                if (r == 1){
                    this.heads ++;
                }else{
                    this.tails ++;
                }
                // System.out.println("Happening...");
                this.iterations ++;
                this.times --;
            }

            updateStats();
        }


        public void updateStats(){
            // read from hashmap and get the existing values
            Long nHeads = (Long)this.map.get("HEADS");
            Long nTails = (Long)this.map.get("TAILS");
            Long nIterations = (Long)this.map.get("ITERATIONS");

            // update values
            nHeads = nHeads + this.heads;
            nTails = nTails + this.tails;
            nIterations = nIterations + this.iterations;

            // push updated values to hashmap
            this.map.put("HEADS", nHeads);
            this.map.put("TAILS", nTails);
            this.map.put("ITERATIONS", nIterations);

        }
    }
}

我使用ConcurrentHashMap来存储不同的计数。显然,当返回错误的值时。

我编写了一个PERL脚本来检查头尾的值(总和)(每个线程单独),这似乎是合适的。我无法理解为什么我从hashmap中获得不同的值。

3 个答案:

答案 0 :(得分:4)

并发哈希映射为您提供有关 地图 本身的更改可见性的保证,而不是其值。在这种情况下,您从地图中检索一些值,保持它们一段任意时间,然后再尝试将它们存储到地图中。在读取和后续写入之间, 地图上可能发生了任意数量的操作

并发哈希映射中的并发只是保证,例如,如果我将一个值放入一个映射中,我实际上能够在另一个线程中读取该值(也就是说它将是可见的)。

当您更新共享计数器时,您需要做的是确保访问地图的所有线程等待轮到。为此,您必须在AtomicInteger上使用像'addAndGet`这样的原子操作:

this.map.get("HEADS").addAndGet(this.heads);

或者您需要手动同步读取和写入(最容易通过在地图上同步来完成):

synchronized(this.map) {
    Long currentHeads = this.map.get("HEADS");
    this.map.put("HEADS", Long.valueOf(currentHeads.longValue() + this.heads);
}

就个人而言,我更愿意随时使用SDK,因此我会使用Atomic数据类型。

答案 1 :(得分:2)

你应该使用AtomicLongs作为值,你应该只创建一次并增加它们而不是get / put。

 ConcurrentHashMap <String, AtomicLong> universalMap = new ConcurrentHashMap <String, AtomicLong> ();
 ...
 universalMap.put("HEADS", new AtomicLong(0));
 universalMap.put("TAILS", new AtomicLong(0));
 universalMap.put("ITERATIONS", new AtomicLong(0));
 ...
 public void updateStats(){
        // read from hashmap and get the existing values
        this.map.get("HEADS").getAndAdd(heads);
        this.map.get("TAILS").getAndAdd(tails);
        this.map.get("ITERATIONS").getAndAdd(iterations);
 }

龙是不变的。

一个例子:

Thread 1: get 0
Thread 2: get 0
Thread 2: put 10
Thread 3: get 10
Thread 3: put 15
Thread 1: put 5

现在你的地图包含5而不是20

基本上你的问题不是地图。您可以使用常规HashMap,因为您不修改它。当然,您必须制作map字段final

答案 2 :(得分:2)

一些事情。一个你真的不需要使用ConcurrentHashMap。 ConcurrentHashMap仅在处理并发put / removed时才有用。在这种情况下,只要键只使用UnmodifiableMap来证明这一点,地图就是相当静态的。

最后,如果你正在处理并发添加,你真的应该考虑使用LongAdder。当发生许多并行添加时,它会更好地扩展,直到最后你不需要担心计数。

public class HeadsTails{
    private final Map<String, LongAdder> map;
    public HeadsTails(){
       Map<String,LongAdder> local = new HashMap<String,LongAdder>();
       local.put("HEADS", new LongAdder());
       local.put("TAILS", new LongAdder());
       local.put("ITERATIONS", new LongAdder());
       map = Collections.unmodifiableMap(local);
    }
    public void count(){
        map.get("HEADS").increment();
        map.get("TAILS").increment();
    }
    public void print(){
        System.out.println(map.get("HEADS").sum());
         /// etc...
    }
}

我的意思是,实际上我甚至不会使用地图......

public class HeadsTails{
    private final LongAdder heads = new LongAdder();
    private final LongAdder tails = new LongAdder();
    private final LongAdder iterations = new LongAdder();
    private final Map<String, LongAdder> map;
    public void count(){
        heads.increment();
        tails.increment();
    }
    public void print(){
        System.out.println(iterations.sum());
    }
}