0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

WaitGroupの内部処理に迫ってみよう!

Posted at

はじめに

みなさん、Go言語触ってますか?
私は最近になってようやくGo言語に触れる機会が増えて大変幸せになっています。

ところで、Go言語って実はオープンソースで、GitHub上で公開されているって知ってましたか??しかもそのソースコードの大半がGo言語で書かれているため、非常に読みやすくなっています。

golang/goリポジトリ

そこで今回は、Goの内部に迫ってみようの第一回として、 WaitGroup を深読みしていきます。

WaitGroupとは

WaitGroupは以下のような使い方をします。

package main

import (
	"fmt"
	"sync"
	"time"
)

func main() {
	var wg sync.WaitGroup

	for range 10 {
		wg.Add(1)
		go func() {
			defer wg.Done()
			time.Sleep(1 * time.Second)
		}()
	}

	wg.Wait()
	fmt.Println("done")
}

WaitGroupは構造体で、その構造体に紐づけられたメソッドがいくつか存在します。イメージとして、 AddメソッドはWaitGroupの内部的なカウンターを1増やし、 Doneメソッドはそのカウンターを1減らします。 Waitメソッドで内部カウンターが0になるまでブロッキングし、0になったら以降の処理を実行します。

WaitGroup構造体

WaitGroup構造体とそのメソッドは、 syncパッケージの waitgroup.goに書かれています。

WaitGroup構造体

// A WaitGroup must not be copied after first use.
type WaitGroup struct {
	noCopy noCopy

	// Bits (high to low):
	//   bits[0:32]  counter
	//   bits[32]    flag: synctest bubble membership
	//   bits[33:64] wait count
	state atomic.Uint64
	sema  uint32
}

一番ミソとなるのが state変数。 atomic.Uint64というのは後述するとして、こいつのビット長は64ビットです。そして state変数の上にあるコメントを読むと、どうも上位32ビットまでは wait count、33ビット目は synctest bubbleのflag、下位31ビットは counterらしいです。
あとは謎に semaという変数がありますね。

atomic.Uint64

atomicパッケージ内の type.goにある構造体です。

atomic.Uint64構造体

// A Uint64 is an atomic uint64. The zero value is zero.
//
// Uint64 must not be copied after first use.
type Uint64 struct {
	_ noCopy
	_ align64
	v uint64
}

Add(delta uint64)

さて、ここで特に重要となるのがこのメソッド。

atomic.Uint64のAddメソッド

func (x *Uint64) Add(delta uint64) (new uint64) { return AddUint64(&x.v, delta) }

とは言っても別段不思議なことはしてませんが、多分 vポインターに deltaを直接加算する、いわゆる即値演算をしているようです。

CompareAndSwap(old, new uint64)

atomic.Uint64のCompareAndSwapメソッド

// CompareAndSwap executes the compare-and-swap operation for x.
func (x *Uint64) CompareAndSwap(old, new uint64) (swapped bool) {
	return CompareAndSwapUint64(&x.v, old, new)
}

私はそこまで低レイヤーに詳しくないので、調べてみました。What is "compare-and-swap"?

そしたら以下のqiita記事が参考になりました。

CASを使ったロックフリー(Lock-free)共有カウンタの作成方法 | qiita

どうも oldという値が xのアドレスにある値と等しければ、 newxのアドレスに代入するみたいです。戻り値の swappedは、値が交換できたかどうかを返すみたいですね。

Add(delta int)メソッド

ついに本丸へ突入!

コード内でいい感じに書かれている部分を抜粋しました。

WaitGroupのAddメソッド

state := wg.state.Add(uint64(delta) << 32) /* ① */

v := int32(state >> 32) /* ② */
w := uint32(state & 0x7fffffff)
/* ③ */
if v < 0 {
	panic("sync: negative WaitGroup counter")
}
if w != 0 && delta > 0 && v == int32(delta) {
	panic("sync: WaitGroup misuse: Add called concurrently with Wait")
}
if v > 0 || w == 0 {
	return
}

if wg.state.Load() != state {
	panic("sync: WaitGroup misuse: Add called concurrently with Wait")
}
// Reset waiters count to 0.
wg.state.Store(0)
for ; w != 0; w-- {
	runtime_Semrelease(&wg.sema, false, 0)
}

まず①の以下のコード部分。

state := wg.state.Add(uint64(delta) << 32) /* ① */

これは deltaを32ビット左シフトして state に加算しています。先ほども解説した通り、 stateは上位32ビットを wait countとして使うため、それを足しているようです。 wg.state.Add()メソッドは先ほど紹介した atomic.Uinit64Addメソッドになります。

その次の②の部分。

v := int32(state >> 32) /* ② */
w := uint32(state & 0x7fffffff)

vは先ほど加算し終わった wait countを取り出しており、 wcounterを取り出しています。最後が 7なのは、 bits[32]がフラグとなっているため、その部分は取得しないようにしているからです。

次に③の部分です。

if v < 0 {
	panic("sync: negative WaitGroup counter")
}
if w != 0 && delta > 0 && v == int32(delta) {
	panic("sync: WaitGroup misuse: Add called concurrently with Wait")
}
if v > 0 || w == 0 {
	return
}

if v < 0というのは純粋に wait countが マイナスになることによって生じます。実際、以下のコードを書くと panicで書かれているメッセージと同じエラー文が出力されます。

package main

import (
	"fmt"
	"sync"
	"time"
)

func main() {
	var wg sync.WaitGroup

	wg.Add(-1) // あえて-1している

	for range 10 {
		wg.Add(1)
		go func() {
			defer wg.Done()
			time.Sleep(1 * time.Second)
		}()
	}

	wg.Wait()
	fmt.Println("done")
}

次の w != 0 && delta > 0 && v == int32(delta)という条件式はなんでしょうか? v == int32(delta)というのは、詰まるところ最初の Addならばということになります。加算した後の vと 引数の deltaが一緒になるには、 vが0、すなわち最初の時だけだからです。では最後の w != 0はなんでしょうか?まず w、つまり counterが持つ役割を知る必要がありそうですが、これについては後述します。

returnは 早期リターンなので、 それより以降の処理は wait countが 0だった時になります。

if wg.state.Load() != state {
	panic("sync: WaitGroup misuse: Add called concurrently with Wait")
}
// Reset waiters count to 0.
wg.state.Store(0)
for ; w != 0; w-- {
	runtime_Semrelease(&wg.sema, false, 0)
}

wg.state.Store(0)wait countをリセットし、 runtime_Semreleaseで何かを処理しています。 runtime_Semreleasew分だけ繰り返しています。

セマフォ

ここで出てきた runtime_Semreleaseはどう言った役割をしているのでしょうか?
結論から言うと、セマフォの役割を果たしています。セマフォとは排他制御のための仕組みの一つです。 runtime_Semreleaseは同パッケージ内の runtime.goで定義されています。

runtime_Semrelease関数

// Semrelease atomically increments *s and notifies a waiting goroutine
// if one is blocked in Semacquire.
// It is intended as a simple wakeup primitive for use by the synchronization
// library and should not be used directly.
// If handoff is true, pass count directly to the first waiter.
// skipframes is the number of frames to omit during tracing, counting from
// runtime_Semrelease's caller.
func runtime_Semrelease(s *uint32, handoff bool, skipframes int)

冒頭に書かれている通り、 Semrelease*s に加算をし、ブロッキングされているゴルーチンに通知します。 increments と言ってるぐらいですから、 +1されるのでしょう。なので wの分だけ繰り返していると言うことですね。
そして sは、 wgでいうところの wg.semaに当たります。

Done()メソッド

WaitGroupのDoneメソッド

func (wg *WaitGroup) Done() {
	wg.Add(-1)
}

個人的にこれが一番衝撃的だったのですが、なんと Doneメソッドは内部的にはただ Add(-1)しているだけなんですね。こんなシンプルだとは思いませんでした...

Wait()メソッド

WaitGroupのWaitメソッド

関係のある箇所だけを抜粋しました。

// Wait blocks until the [WaitGroup] task counter is zero.
func (wg *WaitGroup) Wait() {
	for {
		state := wg.state.Load()
		v := int32(state >> 32)
		w := uint32(state & 0x7fffffff)
		if v == 0 {
			return
		}
		// Increment waiters count.
		if wg.state.CompareAndSwap(state, state+1) {
			runtime_SemacquireWaitGroup(&wg.sema, synctestDurable)
			isReset := wg.state.Load() != 0
			if isReset {
				panic("sync: WaitGroup is reused before previous Wait has returned")
			}
			return
		}
	}
}

state変数に wg.stateの値を代入し、 vwにそれぞれ wait countcounterを代入しています。もし wait countが0なら早期リターンされます。
次の if wg.state.CompareAndSwap(state, state+1)ですが、これは stateを1加算しています。下位31ビットは counterになりますので、これは counter 部分に加算して wg.stateに登録していることになります。
次の runtime_SemacauireWaitGroupは、詳しくは後述しますが、基本的には wの値によってブロッキングされている状態と思ってください。 wg.semaが特定の状態になると、その下の行が実行されます。 isReset := wg.state.Load() != 0は、wait countcounterflagの全ての値が0かどうかをみています。 0ではない時は panicを起こし、それ以外は returnで脱出しています。

runtime_SemacquireWaitGroupの正体

こちらも同パッケージ内の runtime.goにあります。

runtime_SemacquireWaitGroup関数

// Semacquire waits until *s > 0 and then atomically decrements it.
// It is intended as a simple sleep primitive for use by the synchronization
// library and should not be used directly.
func runtime_Semacquire(s *uint32)

// SemacquireWaitGroup is like Semacquire, but for WaitGroup.Wait.
func runtime_SemacquireWaitGroup(s *uint32, synctestDurable bool)

コメントアウトを見ると、 runtime_SemacquireWaitGroupに対する説明として

SemacquireWaitGroupSemaquireみたいなものですが、 WaitGroup.Wait専用です

とあります。では、 Semaquire 、ここで言うところの runtime_Semaquireは何をしているのかというと、

Semaquire*sが0より大きくなるまで待って、その後 *sを減らします。

とあります。

ここで、 runtime_Semreleaseをおさらいしましょう。この関数は、 *sをインクリメントし、待機中のゴルーチンがあればそのゴルーチンに通知するというものでした。そしてこの関数は wait countが0の時に処理されます。
つまり、 wait countが0、すなわち wg.Add(1)された分だけ wg.Done()されたら、 runtime_Semreleasew分だけインクリメントをし、待機中のゴルーチンに通知します。通知されたゴルーチンは runtime_SemacquireWaitGroupによってデクリメントします。待機中のゴルーチンは wの数に等しいですから、 runtime_SemacquireWaitGroupが全てのゴルーチンで実行されたら、セマフォは0になります。

Go(f func())メソッド

WaitGroupのGoメソッド

func (wg *WaitGroup) Go(f func()) {
	wg.Add(1)
	go func() {
		defer func() {
			if x := recover(); x != nil {
				// f panicked, which will be fatal because
				// this is a new goroutine.
				//
				// Calling Done will unblock Wait in the main goroutine,
				// allowing it to race with the fatal panic and
				// possibly even exit the process (os.Exit(0))
				// before the panic completes.
				//
				// This is almost certainly undesirable,
				// so instead avoid calling Done and simply panic.
				panic(x)
			}

			// f completed normally, or abruptly using goexit.
			// Either way, decrement the semaphore.
			wg.Done()
		}()
		f()
	}()
}

内部的には最初にお見せしたサンプルコードと似たようなことをしています。つまり Goメソッドを使用すると以下のように書き換えることができます。

package main

import (
	"fmt"
	"sync"
	"time"
)

func main() {
	var wg sync.WaitGroup

	for range 10 {
		wg.Go(func() {
			time.Sleep(1 * time.Second)
		})
	}

	wg.Wait()
	fmt.Println("done")
}

非常にシンプルになりましたね!基本的に wg.Add(1)wg.Done()はセットでコーディングされるので、それをシンプルにまとめるために Goメソッドが生み出された感じがします。

最後に...更なる深淵へ.....

さて、ここまで WaitGroupという割と馴染み深い構造体とそのメソッドの内部処理について解説しました。
しかしここでみなさん、ある疑問が生まれると思います。そう、 待機状態ってどのようにされて、どのように解除されるのかという疑問。

その疑問についてはまた別の記事にて.......(多分)

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?