Go:可重复使用的屏障,如Java的CyclicBarrier?

时间:2014-07-21 15:41:14

标签: go

使用Google Go,我正在尝试同步多个线程,对图像执行迭代过滤。我的代码基本上就像这里概述的那样:

func filter(src *image.Image, dest *image.Image, start, end, runs int, barrier ??) {
    for i:= 0; i < runs; i++ {
        // ... do image manipulation ...

        // barrier.Await() would work here

        if start == 1 {
            // the first thread switches the images for the next iteration step
            switchImgs(src, dest)
        }

        // barrier.Await() again
     }
}

func main() {
    //...
    barrier := sync.BarrierNew(numberOfThreads)
    for i := 0; i < numberOfThreads; i++ {
        go filter(..., barrier)
    }

问题是我需要一个可重复使用的屏障,就像Java的CyclicBarrier一样,将线程数设置为其计数器值。不幸的是,我发现的唯一类似于障碍的实现是sync.WaitGroup。但是WaitGroup不能原子地重置为它之前的计数器值。它只提供一个普通的Wait()函数,不会重置计数器值。

是否存在实现我想要的“Go idiomatic”方式,或者我应该实现自己的CyclicBarrier?非常感谢你的帮助!

2 个答案:

答案 0 :(得分:0)

我不完全明白CyclicBarrier是如何工作的,所以如果我离开的话,请原谅。

围绕SyncGroup的一个非常简单的包装器应该可以胜任,例如:

type Barrier struct {
    NumOfThreads int
    wg           sync.WaitGroup
}

func NewBarrier(num int) (b *Barrier) {
    b = &Barrier{NumOfThreads: num}
    b.wg.Add(num)
    return
}

func (b *Barrier) Await() {
    b.wg.Wait()
    b.wg.Add(b.NumOfThreads)
}

func (b *Barrier) Done() {
    b.wg.Done()
}
func filter(src *image.Image, dest *image.Image, start, end, runs int, barrier *Barrier) {
    for i := 0; i < runs; i++ {
        // ... do image manipulation ...
        //this filter is done, say so by using b.Done()
        b.Done()
        b.Await()
        if start == 1 {
            // the first thread switches the images for the next iteration step
            //switchImgs(src, dest)
        }

        b.Done()
        b.Await()
    }
}

func main() {
    barrier := NewBarrier(5)
    for i := 0; i < barrier.NumOfThreads; i++ {
        go filter(1, barrier)
    }
}

答案 1 :(得分:0)

您可以使用sync.Cond来实现CyclicBarrier,请参阅source code of java's CyclicBarrier

这是CyclicBarrier的最小版本(没有超时,没有线程中断): http://play.golang.org/p/5JSNTm0BLe

type CyclicBarrier struct {
    generation int
    count      int
    parties    int
    trip       *sync.Cond
}

func (b *CyclicBarrier) nextGeneration() {
    // signal completion of last generation
    b.trip.Broadcast()
    b.count = b.parties
    // set up next generation
    b.generation++
}

func (b *CyclicBarrier) Await() {
    b.trip.L.Lock()
    defer b.trip.L.Unlock()

    generation := b.generation

    b.count--
    index := b.count
    //println(index)

    if index == 0 {
        b.nextGeneration()
    } else {
        for generation == b.generation {
            //wait for current generation complete
            b.trip.Wait()
        }
    }
}

func NewCyclicBarrier(num int) *CyclicBarrier {
    b := CyclicBarrier{}
    b.count = num
    b.parties = num
    b.trip = sync.NewCond(&sync.Mutex{})

    return &b
}