使用 golang 实现类似 pthread_barrier_t 语义的 barrier 对象

看到golang标准库sync package WaitGroup 类型, 本以为是golang 版本的 barrier 对象实现,看到文档给出的使用示例:

 var wg sync.WaitGroup
    var urls = []string{
            "http://www.golang.org/",
            "http://www.google.com/",
            "http://www.somestupidname.com/",
    }
    for _, url := range urls {           
            // Increment the WaitGroup counter.
            wg.Add(1)            
            // Launch a goroutine to fetch the URL.
            go func(url string) {                   
             // Decrement the counter when the goroutine completes.
                    defer wg.Done()                    
                    // Fetch the URL.
                    http.Get(url)
            }(url)
    }    
    // Wait for all HTTP fetches to complete.
    wg.Wait()

可以看出WaitGroup 类型主要用于某个goroutine(调用Wait() 方法的那个),  等待个数不定goroutine(内部调用Done() 方法),

Add 方法对内部计数,添加或减少,Done方法其实是Add(-1);

与pthread_barrier_t 有着语义上的差别,pthread_barrier_wait() 的调用者之间互相等待,就好比5名队员(线程)参加跨栏比赛,使用 pthread_barrier_init 初始化最后一个参数为5,  五个队员都是好基友, 定了规矩, 不管谁先到栏杆, 都要等队友,直到最后一名队员跨过栏时,然后同一起步点再次出发。下面时使用pthread_barrier_t 的简单示例 5个线程,每个线程拥有一个私有数组,及增量数字:

#define _GNU_SOURCE 

#include <pthread.h>
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#define NTHR 5
#define NARR 6
#define INLOOPS 1000
#define OUTLOOPS 10
#define err_abort(code,text) do { \
    char errbuf[128] = {0};         \
    fprintf (stderr, "%s at \"%s\":%d: %s\n", \
        (text), __FILE__, __LINE__, strerror_r(code,errbuf,128)); \
    abort (); \
} while (0)

typedef struct thrArg {
    pthread_t   tid;
    int         incr;
    int         arr[NARR];
}thrArg;

pthread_barrier_t   barrier;
thrArg  thrs[NTHR];

void *thrFunc (void *arg)
{
    thrArg *self = (thrArg*)arg;    
    int j, i, k, status;
    
    for (i = 0; i < OUTLOOPS; i++) {
        status = pthread_barrier_wait (&barrier);
        if (status > 0)
            err_abort (status, "wait on barrier");
        //每个线程迭代 INLOOPS 次,对自己的内部数组arr 成员加上 自己的增量值
        for (j = 0; j < INLOOPS; j++)
            for (k = 0; k < NARR; k++)
                self->arr[k] += self->incr;
        //先执行完迭代的线程在此等待,直到最后一个到达
        status = pthread_barrier_wait (&barrier);
        if (status > 0)
            err_abort (status, "wait on barrier");
        //最后一个到达的线程,把所有线程的内部增量加1
        //此时其他先到的线程阻塞在第一次wait调用处,所以最后一个到达的线程
        //可以排他性地访问所有线程的内部状态,if 语句执行完后,跳到第一次wait处,
        //其他阻塞在第一次wait处的线程,得到释放,大家一块使用新的增量做计算
        if (status == PTHREAD_BARRIER_SERIAL_THREAD ) {
            int i;
            for (i = 0; i < NTHR; i++)
                thrs[i].incr += 1;
        }
    }
    return NULL;
}

int main (int arg, char *argv[])
{
    int i, j;
    int status;

    pthread_barrier_init (&barrier, NULL, NTHR);

    for (i = 0; i < NTHR; i++) {
        thrs[i].incr = i;
        for (j = 0; j < NARR; j++)
            thrs[i].arr[j] = j + 1;

        status = pthread_create (&thrs[i].tid,
            NULL, thrFunc, (void*)&thrs[i]);
        if (status != 0)
            err_abort (status, "create thread");
    }

    for (i = 0; i < NTHR; i++) {
        status = pthread_join (thrs[i].tid, NULL);
        if (status != 0)
            err_abort (status, "join thread");

        printf ("%02d: (%d) ", i, thrs[i].incr);

        for (j = 0; j < NARR; j++)
            printf ("%010u ", thrs[i].arr[j]);
        printf ("\n");
    }
    pthread_barrier_destroy (&barrier);
    return 0;
}

怎么用golang 来表达上述c 代码,需要实现pthread_barrier_t 等价语义的的 barrier 对象,可以使用golang 已有的mutex, cond

对象实现 barrier:

package main
import (
    "fmt"
    "sync"
)
type Barrier struct{
    lock  sync.Mutex
    cond  sync.Cond
    threshold  int    //总的等待个数
    count      int    //还剩多少没有到达barrier,即没有完成wait调用个数
    cycle      bool   //用于重初始化下一个wait 周期,
}
func NewBarrier(n  int) *Barrier{
    b := &Barrier{threshold: n, count: n} 
    b.cond.L = &b.lock
    return b
}
//last == true ,说明最有一个到达
func (b *Barrier)Wait()(last bool){
    b.lock.Lock()
    defer  b.lock.Unlock()
    cycle :=  b.cycle
    b.count--
    //最后一个到达负责,重初始化count 计数,cycle 变量翻转,
    if b.count == 0 {
       b.cycle  =  !b.cycle 
       b.count = b.threshold 
       b.cond.Broadcast()
       last = true
    }else{
      for cycle == b.cycle {
          b.cond.Wait()
      }
    }
    return
}
type thrArg struct{
   incr  int
   arr   [narr]int
}
var (
    thrs  [nthr]thrArg
    wg   sync.WaitGroup
    barrier = NewBarrier(nthr)
)
const (
    outloops = 10
    inloops  = 1000
    nthr  = 5
    narr  = 6
)

func thrFunc(arg  *thrArg){
    defer wg.Done()
    for i := 0; i < outloops; i++{
        barrier.Wait()
        for j := 0; j < inloops; j++{
            for k:= 0; k < narr; k++{
                arg.arr[k] += arg.incr
            }
        }
        if barrier.Wait() {
            for i := 0; i < nthr; i++{
                thrs[i].incr += 1
            }
        }
    }
}

func  main(){
    for i:= 0; i < nthr; i++{
        thrs[i].incr =  i
        for j := 0; j < narr; j++{
            thrs[i].arr[j] = j + 1
        }
        wg.Add(1)
        go thrFunc(&thrs[i])
    }
    wg.Wait()
    //所有goroutine完成,main goroutine,检查最后的结果
    for i := 0; i < nthr; i++{
        fmt.Printf("%02d: (%d) ", i, thrs[i].incr)
        for j := 0; j < narr; j++{
            fmt.Printf ("%010d ", thrs[i].arr[j]);
        }
        fmt.Println()
    }
}














你可能感兴趣的:(c,linux,线程,pthread,golang)