天天看点

Go语言之sync包 WaitGroup的使用

WaitGroup 是什么以及它能为我们解决什么问题?

WaitGroup

在go语言中,用于线程同步,单从字面意思理解,

wait

等待的意思,

group

组、团队的意思,

WaitGroup

就是指等待一组,等待一个系列执行完成后才会继续向下执行。

正常情况下,

goroutine

的结束过程是不可控制的,我们可以保证的只有

main goroutine

的终止。

这时候可以借助

sync

包的

WaitGroup

来判断

goroutine

是否完成。

WaitGroup介绍

WatiGroup

sync

包中的一个

struct

类型,用来收集需要等待执行完成的

goroutine

。下面是它的定义:

// WaitGroup用于等待一组线程的结束。
// 父线程调用Add方法来设定应等待的线程的数量。
// 每个被等待的线程在结束时应调用Done方法。同时,主线程里可以调用Wait方法阻塞至所有线程结束。
type WaitGroup struct {
    // 包含隐藏或非导出字段
}

// Add方法向内部计数加上delta,delta可以是负数;
// 如果内部计数器变为0,Wait方法阻塞等待的所有线程都会释放,如果计数器小于0,方法panic。
// 注意Add加上正数的调用应在Wait之前,否则Wait可能只会等待很少的线程。
// 一般来说本方法应在创建新的线程或者其他应等待的事件之前调用。
func (wg *WaitGroup) Add(delta int)

// Done方法减少WaitGroup计数器的值,应在线程的最后执行。
func (wg *WaitGroup) Done()

// Wait方法阻塞直到WaitGroup计数器减为0。
func (wg *WaitGroup) Wait()
           

它有3个方法:

    Add():每次激活想要被等待完成的

goroutine

之前,先调用Add(),用来设置或添加要等待完成的

goroutine

数量

        例如Add(2)或者两次调用Add(1)都会设置等待计数器的值为2,表示要等待2个

goroutine

完成

    Done():每次需要等待的

goroutine

在真正完成之前,应该调用该方法来人为表示

goroutine

完成了,该方法会对等待计数器减1

    Wait():在等待计数器减为0之前,Wait()会一直阻塞当前的

goroutine

    也就是说,Add()用来增加要等待的

goroutine

的数量,Done()用来表示

goroutine

已经完成了,减少一次计数器,Wait()用来等待所有需要等待的

goroutine

完成。

示例一

package main

import (
    "fmt"
    "sync"
    "time"
)

// 每个协程都会运行该函数。
// 注意,WaitGroup 必须通过指针传递给函数。
func worker(id int, wg *sync.WaitGroup) {
    fmt.Printf("Worker %d starting\n", id)

    // 睡眠一秒钟,以此来模拟耗时的任务。
    time.Sleep(time.Second)
    fmt.Printf("Worker %d done\n", id)

    // 通知 WaitGroup ,当前协程的工作已经完成。
    wg.Done()
}

func main() {

    // 这个 WaitGroup 被用于等待该函数开启的所有协程。
    var wg sync.WaitGroup

    // 开启几个协程,并为其递增 WaitGroup 的计数器。
    for i := 1; i <= 5; i++ {
        wg.Add(1)
        go worker(i, &wg)
    }

    // 阻塞,直到 WaitGroup 计数器恢复为 0,即所有协程的工作都已经完成。
    wg.Wait()
}
           

main中开启了5个协程,开启协程之前都先调用了Add()方法增加了一个需要等待

goroutine

计数。每个

goroutine

都运行worker()函数,这个函数执行完成后调用Done()方法通知

WaitGroup

表示当前协程的完成。

有一点需要注意,worker()函数中使用了指针类型的

*sync.WaitGroup

作为参数,这里不能使用值类型的

sync.WaitGroup

作为参数,因为这意味着每个

goroutine

都拷贝一份wg,每个

goroutine

都使用自己的wg。这显然是不合理的,这5个协程应该共享一个wg,这样才能知道这几个协程都完成了。实际上,如果使用值类型的参数,

main goroutine

将会永久阻塞而导致产生死锁。

还有一点需要注意

Add

Done

函数一定要配对,否则可能发生死锁,所报的错误信息如下:

fatal error: all goroutines are asleep - deadlock!
           

运行:

go run waitgroups.go
Worker 5 starting
Worker 3 starting
Worker 4 starting
Worker 1 starting
Worker 2 starting
Worker 4 done
Worker 1 done
Worker 2 done
Worker 5 done
Worker 3 done
           

每次运行,各个协程开启和完成的时间可能是不同的。

示例二

在工作中使用时,等待一个协程组全部正确完成则结束;但其中一个协程发生错误,这时候就会阻塞了,不推荐这种用法。

这种场景就需要使用到通知机制,这时候可以使用

channel

来实现。

package main

import (
	"fmt"
	"sync"
	"time"
)


func main(){
	// 这个 WaitGroup 被用于等待该函数开启的所有协程。
	var wg sync.WaitGroup

	// Add()方法开启了3个等待的协程计数
	wg.Add(3)

        // 开启3个协程,用于工作处理
	go work1(&wg)
	go work2(&wg)
	go work3(&wg)

	// 阻塞,直到 WaitGroup 计数器恢复为 0,即所有协程的工作都已经完成。
	wg.Wait()
}

func work1(wg *sync.WaitGroup){
	fmt.Println("work1 starting")

	// 睡眠一秒钟,以此来模拟耗时的任务。
	time.Sleep(time.Second)
	fmt.Println("work1 done")

	// 通知 WaitGroup ,当前协程的工作已经完成。
	wg.Done()
}

func work2(wg *sync.WaitGroup){
	fmt.Println("work2 starting")

	// 睡眠一秒钟,以此来模拟耗时的任务。
	time.Sleep(time.Second)
	fmt.Println("work2 done")

	// 通知 WaitGroup ,当前协程的工作已经完成。
	wg.Done()
}

func work3(wg *sync.WaitGroup){
	fmt.Println("work3 starting")

	// 睡眠一秒钟,以此来模拟耗时的任务。
	time.Sleep(time.Second)
	fmt.Println("work3 done")

	// 通知 WaitGroup ,当前协程的工作已经完成。
	wg.Done()
}
           

源码分析

type WaitGroup struct {
	noCopy noCopy

	// 64-bit value: high 32 bits are counter, low 32 bits are waiter count.
	// 64-bit atomic operations require 64-bit alignment, but 32-bit
	// compilers do not ensure it. So we allocate 12 bytes and then use
	// the aligned 8 bytes in them as state, and the other 4 as storage
	// for the sema.
	state1 [3]uint32
}
           

WaitGroup

结构十分简单,由

nocopy

state1

两个字段组成,其中

nocopy

是用来防止复制的

type noCopy struct{}

// Lock is a no-op used by -copylocks checker from `go vet`.
func (*noCopy) Lock()   {}
func (*noCopy) Unlock() {}
           

由于嵌入了

nocopy

所以在执行

go vet

时如果检查到

WaitGroup

被复制了就会报错。这样可以一定程度上保证

WaitGroup

不被复制,对了直接

go run

是不会有错误的,所以我们代码

push

之前都会强制要求进行

lint

检查,在

ci/cd

阶段也需要先进行

lint

检查,避免出现这种类似的错误。

~/project/Go-000/Week03/blog/06_waitgroup/02 main*
❯ go run ./main.go

~/project/Go-000/Week03/blog/06_waitgroup/02 main*
❯ go vet .
# github.com/mohuishou/go-training/Week03/blog/06_waitgroup/02
./main.go:7:9: assignment copies lock value to wg2: sync.WaitGroup contains sync.noCopy
           

state1

的设计非常巧妙,这是一个是十二字节的数据,这里面主要包含两大块,

counter

占用了 8 字节用于计数,

sema

占用 4 字节用做信号量

为什么要这么搞呢?直接用两个字段一个表示

counter

,一个表示

sema

不行么?

不行,我们看看注释里面怎么写的。

// 64-bit value: high 32 bits are counter, low 32 bits are waiter count. > // 64-bit atomic operations require 64-bit alignment, but 32-bit > // compilers do not ensure it. So we allocate 12 bytes and then use > // the aligned 8 bytes in them as state, and the other 4 as storage > // for the sema.

这段话的关键点在于,在做 64 位的原子操作的时候必须要保证 64 位(8 字节)对齐,如果没有对齐的就会有问题,但是 32 位的编译器并不能保证 64 位对齐所以这里用一个 12 字节的 state1 字段来存储这两个状态,然后根据是否 8 字节对齐选择不同的保存方式。

Go语言之sync包 WaitGroup的使用

这个操作巧妙在哪里呢?

  • 如果是 64 位的机器那肯定是 8 字节对齐了的,所以是上面第一种方式
  • 如果在 32 位的机器上

    如果恰好 8 字节对齐了,那么也是第一种方式取前面的 8 字节数据

    如果是没有对齐,但是 32 位 4 字节是对齐了的,所以我们只需要后移四个字节,那么就 8 字节对齐了,所以是第二种方式

所以通过

sema

信号量这四个字节的位置不同,保证了

counter

这个字段无论在 32 位还是 64 为机器上都是 8 字节对齐的,后续做 64 位原子操作的时候就没问题了。

这个实现是在

state

方法实现的

func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {
	if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
		return (*uint64)(unsafe.Pointer(&wg.state1)), &wg.state1[2]
	} else {
		return (*uint64)(unsafe.Pointer(&wg.state1[1])), &wg.state1[0]
	}
}
           

state

方法返回

counter

和信号量,通过

uintptr(unsafe.Pointer(&wg.state1))%8 == 0

来判断是否 8 字节对齐

Add

func (wg *WaitGroup) Add(delta int) {
    // 先从 state 当中把数据和信号量取出来
	statep, semap := wg.state()

    // 在 waiter 上加上 delta 值
	state := atomic.AddUint64(statep, uint64(delta)<<32)
    // 取出当前的 counter
	v := int32(state >> 32)
    // 取出当前的 waiter,正在等待 goroutine 数量
	w := uint32(state)

    // counter 不能为负数
	if v < 0 {
		panic("sync: negative WaitGroup counter")
	}

    // 这里属于防御性编程
    // w != 0 说明现在已经有 goroutine 在等待中,说明已经调用了 Wait() 方法
    // 这时候 delta > 0 && v == int32(delta) 说明在调用了 Wait() 方法之后又想加入新的等待者
    // 这种操作是不允许的
	if w != 0 && delta > 0 && v == int32(delta) {
		panic("sync: WaitGroup misuse: Add called concurrently with Wait")
	}
    // 如果当前没有人在等待就直接返回,并且 counter > 0
	if v > 0 || w == 0 {
		return
	}

    // 这里也是防御 主要避免并发调用 add 和 wait
	if *statep != state {
		panic("sync: WaitGroup misuse: Add called concurrently with Wait")
	}

	// 唤醒所有 waiter,看到这里就回答了上面的问题了
	*statep = 0
	for ; w != 0; w-- {
		runtime_Semrelease(semap, false, 0)
	}
}
           

Wait

wait

主要就是等待其他的

goroutine

完事之后唤醒

func (wg *WaitGroup) Wait() {
	// 先从 state 当中把数据和信号量的地址取出来
    statep, semap := wg.state()

	for {
     	// 这里去除 counter 和 waiter 的数据
		state := atomic.LoadUint64(statep)
		v := int32(state >> 32)
		w := uint32(state)

        // counter = 0 说明没有在等的,直接返回就行
        if v == 0 {
			// Counter is 0, no need to wait.
			return
		}

		// waiter + 1,调用一次就多一个等待者,然后休眠当前 goroutine 等待被唤醒
		if atomic.CompareAndSwapUint64(statep, state, state+1) {
			runtime_Semacquire(semap)
			if *statep != 0 {
				panic("sync: WaitGroup is reused before previous Wait has returned")
			}
			return
		}
	}
}
           

Done

func (wg *WaitGroup) Done() {
	wg.Add(-1)
}
           

总结

通过

WaitGroup

提供的三个函数:

Add

,

Done

Wait

,可以轻松实现等待某个协程或协程组完成的同步操作。但在使用时要注意:

  • WaitGroup

    可以用于一个

    goroutine

    等待多个

    goroutine

    干活完成,也可以多个

    goroutine

    等待一个

    goroutine

    干活完成,是一个多对多的关系

    多个等待一个的典型案例是 singleflight,这个在后面将微服务可用性的时候还会再讲到,感兴趣可以看看源码

  • Add(n>0)

    方法应该在启动

    goroutine

    之前调用,然后在

    goroution

    内部调用

    Done

    方法
  • WaitGroup

    必须在

    Wait

    方法返回之后才能再次使用
  • Done

    只是

    Add

    的简单封装,所以实际上是可以通过一次加一个比较大的值减少调用,或者达到快速唤醒的目的。
  • 协程函数要使用指针类型的

    *sync.WaitGroup

    作为参数,不能使用值类型的

    sync.WaitGroup

    作为参数
  • Add的数量和Done的调用数量必须相等,否则可能发生死锁
上一篇: 初识指针