天天看點

Go語言之sync包 WaitGroup的使用

WaitGroup 是什麼以及它能為我們解決什麼問題?

WaitGroup

在go語言中,用于線程同步,單從字面意思了解,

wait

等待的意思,

group

組、團隊的意思,

WaitGroup

就是指等待一組,等待一個系列執行完成後才會繼續向下執行。

正常情況下,

goroutine

的結束過程是不可控制的,我們可以保證的隻有

main goroutine

的終止。

這時候可以借助

sync

包的

WaitGroup

來判斷

goroutine

是否完成。

WaitGroup介紹

WatiGroup

sync

包中的一個

struct

類型,用來收集需要等待執行完成的

goroutine

。下面是它的定義:

// WaitGroup用于等待一組線程的結束。
// 父線程調用Add方法來設定應等待的線程的數量。
// 每個被等待的線程在結束時應調用Done方法。同時,主線程裡可以調用Wait方法阻塞至所有線程結束。
type WaitGroup struct {
    // 包含隐藏或非導出字段
}

// Add方法向内部計數加上delta,delta可以是負數;
// 如果内部計數器變為0,Wait方法阻塞等待的所有線程都會釋放,如果計數器小于0,方法panic。
// 注意Add加上正數的調用應在Wait之前,否則Wait可能隻會等待很少的線程。
// 一般來說本方法應在建立新的線程或者其他應等待的事件之前調用。
func (wg *WaitGroup) Add(delta int)

// Done方法減少WaitGroup計數器的值,應線上程的最後執行。
func (wg *WaitGroup) Done()

// Wait方法阻塞直到WaitGroup計數器減為0。
func (wg *WaitGroup) Wait()
           

它有3個方法:

    Add():每次激活想要被等待完成的

goroutine

之前,先調用Add(),用來設定或添加要等待完成的

goroutine

數量

        例如Add(2)或者兩次調用Add(1)都會設定等待計數器的值為2,表示要等待2個

goroutine

完成

    Done():每次需要等待的

goroutine

在真正完成之前,應該調用該方法來人為表示

goroutine

完成了,該方法會對等待計數器減1

    Wait():在等待計數器減為0之前,Wait()會一直阻塞目前的

goroutine

    也就是說,Add()用來增加要等待的

goroutine

的數量,Done()用來表示

goroutine

已經完成了,減少一次計數器,Wait()用來等待所有需要等待的

goroutine

完成。

示例一

package main

import (
    "fmt"
    "sync"
    "time"
)

// 每個協程都會運作該函數。
// 注意,WaitGroup 必須通過指針傳遞給函數。
func worker(id int, wg *sync.WaitGroup) {
    fmt.Printf("Worker %d starting\n", id)

    // 睡眠一秒鐘,以此來模拟耗時的任務。
    time.Sleep(time.Second)
    fmt.Printf("Worker %d done\n", id)

    // 通知 WaitGroup ,目前協程的工作已經完成。
    wg.Done()
}

func main() {

    // 這個 WaitGroup 被用于等待該函數開啟的所有協程。
    var wg sync.WaitGroup

    // 開啟幾個協程,并為其遞增 WaitGroup 的計數器。
    for i := 1; i <= 5; i++ {
        wg.Add(1)
        go worker(i, &wg)
    }

    // 阻塞,直到 WaitGroup 計數器恢複為 0,即所有協程的工作都已經完成。
    wg.Wait()
}
           

main中開啟了5個協程,開啟協程之前都先調用了Add()方法增加了一個需要等待

goroutine

計數。每個

goroutine

都運作worker()函數,這個函數執行完成後調用Done()方法通知

WaitGroup

表示目前協程的完成。

有一點需要注意,worker()函數中使用了指針類型的

*sync.WaitGroup

作為參數,這裡不能使用值類型的

sync.WaitGroup

作為參數,因為這意味着每個

goroutine

都拷貝一份wg,每個

goroutine

都使用自己的wg。這顯然是不合理的,這5個協程應該共享一個wg,這樣才能知道這幾個協程都完成了。實際上,如果使用值類型的參數,

main goroutine

将會永久阻塞而導緻産生死鎖。

還有一點需要注意

Add

Done

函數一定要配對,否則可能發生死鎖,所報的錯誤資訊如下:

fatal error: all goroutines are asleep - deadlock!
           

運作:

go run waitgroups.go
Worker 5 starting
Worker 3 starting
Worker 4 starting
Worker 1 starting
Worker 2 starting
Worker 4 done
Worker 1 done
Worker 2 done
Worker 5 done
Worker 3 done
           

每次運作,各個協程開啟和完成的時間可能是不同的。

示例二

在工作中使用時,等待一個協程組全部正确完成則結束;但其中一個協程發生錯誤,這時候就會阻塞了,不推薦這種用法。

這種場景就需要使用到通知機制,這時候可以使用

channel

來實作。

package main

import (
	"fmt"
	"sync"
	"time"
)


func main(){
	// 這個 WaitGroup 被用于等待該函數開啟的所有協程。
	var wg sync.WaitGroup

	// Add()方法開啟了3個等待的協程計數
	wg.Add(3)

        // 開啟3個協程,用于工作處理
	go work1(&wg)
	go work2(&wg)
	go work3(&wg)

	// 阻塞,直到 WaitGroup 計數器恢複為 0,即所有協程的工作都已經完成。
	wg.Wait()
}

func work1(wg *sync.WaitGroup){
	fmt.Println("work1 starting")

	// 睡眠一秒鐘,以此來模拟耗時的任務。
	time.Sleep(time.Second)
	fmt.Println("work1 done")

	// 通知 WaitGroup ,目前協程的工作已經完成。
	wg.Done()
}

func work2(wg *sync.WaitGroup){
	fmt.Println("work2 starting")

	// 睡眠一秒鐘,以此來模拟耗時的任務。
	time.Sleep(time.Second)
	fmt.Println("work2 done")

	// 通知 WaitGroup ,目前協程的工作已經完成。
	wg.Done()
}

func work3(wg *sync.WaitGroup){
	fmt.Println("work3 starting")

	// 睡眠一秒鐘,以此來模拟耗時的任務。
	time.Sleep(time.Second)
	fmt.Println("work3 done")

	// 通知 WaitGroup ,目前協程的工作已經完成。
	wg.Done()
}
           

源碼分析

type WaitGroup struct {
	noCopy noCopy

	// 64-bit value: high 32 bits are counter, low 32 bits are waiter count.
	// 64-bit atomic operations require 64-bit alignment, but 32-bit
	// compilers do not ensure it. So we allocate 12 bytes and then use
	// the aligned 8 bytes in them as state, and the other 4 as storage
	// for the sema.
	state1 [3]uint32
}
           

WaitGroup

結構十分簡單,由

nocopy

state1

兩個字段組成,其中

nocopy

是用來防止複制的

type noCopy struct{}

// Lock is a no-op used by -copylocks checker from `go vet`.
func (*noCopy) Lock()   {}
func (*noCopy) Unlock() {}
           

由于嵌入了

nocopy

是以在執行

go vet

時如果檢查到

WaitGroup

被複制了就會報錯。這樣可以一定程度上保證

WaitGroup

不被複制,對了直接

go run

是不會有錯誤的,是以我們代碼

push

之前都會強制要求進行

lint

檢查,在

ci/cd

階段也需要先進行

lint

檢查,避免出現這種類似的錯誤。

~/project/Go-000/Week03/blog/06_waitgroup/02 main*
❯ go run ./main.go

~/project/Go-000/Week03/blog/06_waitgroup/02 main*
❯ go vet .
# github.com/mohuishou/go-training/Week03/blog/06_waitgroup/02
./main.go:7:9: assignment copies lock value to wg2: sync.WaitGroup contains sync.noCopy
           

state1

的設計非常巧妙,這是一個是十二位元組的資料,這裡面主要包含兩大塊,

counter

占用了 8 位元組用于計數,

sema

占用 4 位元組用做信号量

為什麼要這麼搞呢?直接用兩個字段一個表示

counter

,一個表示

sema

不行麼?

不行,我們看看注釋裡面怎麼寫的。

// 64-bit value: high 32 bits are counter, low 32 bits are waiter count. > // 64-bit atomic operations require 64-bit alignment, but 32-bit > // compilers do not ensure it. So we allocate 12 bytes and then use > // the aligned 8 bytes in them as state, and the other 4 as storage > // for the sema.

這段話的關鍵點在于,在做 64 位的原子操作的時候必須要保證 64 位(8 位元組)對齊,如果沒有對齊的就會有問題,但是 32 位的編譯器并不能保證 64 位對齊是以這裡用一個 12 位元組的 state1 字段來存儲這兩個狀态,然後根據是否 8 位元組對齊選擇不同的儲存方式。

Go語言之sync包 WaitGroup的使用

這個操作巧妙在哪裡呢?

  • 如果是 64 位的機器那肯定是 8 位元組對齊了的,是以是上面第一種方式
  • 如果在 32 位的機器上

    如果恰好 8 位元組對齊了,那麼也是第一種方式取前面的 8 位元組資料

    如果是沒有對齊,但是 32 位 4 位元組是對齊了的,是以我們隻需要後移四個位元組,那麼就 8 位元組對齊了,是以是第二種方式

是以通過

sema

信号量這四個位元組的位置不同,保證了

counter

這個字段無論在 32 位還是 64 為機器上都是 8 位元組對齊的,後續做 64 位原子操作的時候就沒問題了。

這個實作是在

state

方法實作的

func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {
	if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
		return (*uint64)(unsafe.Pointer(&wg.state1)), &wg.state1[2]
	} else {
		return (*uint64)(unsafe.Pointer(&wg.state1[1])), &wg.state1[0]
	}
}
           

state

方法傳回

counter

和信号量,通過

uintptr(unsafe.Pointer(&wg.state1))%8 == 0

來判斷是否 8 位元組對齊

Add

func (wg *WaitGroup) Add(delta int) {
    // 先從 state 當中把資料和信号量取出來
	statep, semap := wg.state()

    // 在 waiter 上加上 delta 值
	state := atomic.AddUint64(statep, uint64(delta)<<32)
    // 取出目前的 counter
	v := int32(state >> 32)
    // 取出目前的 waiter,正在等待 goroutine 數量
	w := uint32(state)

    // counter 不能為負數
	if v < 0 {
		panic("sync: negative WaitGroup counter")
	}

    // 這裡屬于防禦性程式設計
    // w != 0 說明現在已經有 goroutine 在等待中,說明已經調用了 Wait() 方法
    // 這時候 delta > 0 && v == int32(delta) 說明在調用了 Wait() 方法之後又想加入新的等待者
    // 這種操作是不允許的
	if w != 0 && delta > 0 && v == int32(delta) {
		panic("sync: WaitGroup misuse: Add called concurrently with Wait")
	}
    // 如果目前沒有人在等待就直接傳回,并且 counter > 0
	if v > 0 || w == 0 {
		return
	}

    // 這裡也是防禦 主要避免并發調用 add 和 wait
	if *statep != state {
		panic("sync: WaitGroup misuse: Add called concurrently with Wait")
	}

	// 喚醒所有 waiter,看到這裡就回答了上面的問題了
	*statep = 0
	for ; w != 0; w-- {
		runtime_Semrelease(semap, false, 0)
	}
}
           

Wait

wait

主要就是等待其他的

goroutine

完事之後喚醒

func (wg *WaitGroup) Wait() {
	// 先從 state 當中把資料和信号量的位址取出來
    statep, semap := wg.state()

	for {
     	// 這裡去除 counter 和 waiter 的資料
		state := atomic.LoadUint64(statep)
		v := int32(state >> 32)
		w := uint32(state)

        // counter = 0 說明沒有在等的,直接傳回就行
        if v == 0 {
			// Counter is 0, no need to wait.
			return
		}

		// waiter + 1,調用一次就多一個等待者,然後休眠目前 goroutine 等待被喚醒
		if atomic.CompareAndSwapUint64(statep, state, state+1) {
			runtime_Semacquire(semap)
			if *statep != 0 {
				panic("sync: WaitGroup is reused before previous Wait has returned")
			}
			return
		}
	}
}
           

Done

func (wg *WaitGroup) Done() {
	wg.Add(-1)
}
           

總結

通過

WaitGroup

提供的三個函數:

Add

,

Done

Wait

,可以輕松實作等待某個協程或協程組完成的同步操作。但在使用時要注意:

  • WaitGroup

    可以用于一個

    goroutine

    等待多個

    goroutine

    幹活完成,也可以多個

    goroutine

    等待一個

    goroutine

    幹活完成,是一個多對多的關系

    多個等待一個的典型案例是 singleflight,這個在後面将微服務可用性的時候還會再講到,感興趣可以看看源碼

  • Add(n>0)

    方法應該在啟動

    goroutine

    之前調用,然後在

    goroution

    内部調用

    Done

    方法
  • WaitGroup

    必須在

    Wait

    方法傳回之後才能再次使用
  • Done

    隻是

    Add

    的簡單封裝,是以實際上是可以通過一次加一個比較大的值減少調用,或者達到快速喚醒的目的。
  • 協程函數要使用指針類型的

    *sync.WaitGroup

    作為參數,不能使用值類型的

    sync.WaitGroup

    作為參數
  • Add的數量和Done的調用數量必須相等,否則可能發生死鎖
上一篇: 初識指針