ゴルーチンを使って並行処理を書きたいが同時実行数を制御したいという場面は多々ある。メモリ・CPUなどリソースは限られてますから。
バッファ付きを使ったゴルーチン数の制御
バッファ付きチャネルの次の特性を使って、計数セマフォとして使うことで簡単に実現できる。
- 空きがあれば送信側は待たされない
- 空きがなければ送信側は待たされる(ブロックされる)
package main
import (
"fmt"
"sync"
"time"
)
const concurrencyNum = 3 // 同時実行数
func main() {
tasks := []string{"タスクA", "タスクB", "タスクC", "タスクD", "タスクE"}
sem := make(chan struct{}, concurrencyNum)
var wg sync.WaitGroup
for _, t := range tasks {
sem <- struct{}{} //
wg.Add(1)
go func(t string) {
defer func() {
wg.Done()
<-sem // チャネルから値を送信して空きを作る(解放)
}()
doTask(t)
}(t)
}
wg.Wait()
fmt.Println("おわり")
}
func doTask(name string) {
fmt.Printf("[%s] 実行中\n", name)
time.Sleep(300 * time.Millisecond) // 時間のかかる処理
fmt.Printf("[%s] 完了\n", name)
}
実行結果
[タスクC] 実行中
[タスクA] 実行中
[タスクB] 実行中 // ← タスクA,B,Cの3つが起動
[タスクB] 完了
[タスクA] 完了
[タスクD] 実行中 // ← 空きができたのでDが起動
[タスクC] 完了
[タスクE] 実行中 // ← 空きができたのでEが起動
[タスクD] 完了
[タスクE] 完了
おわり
もし、各ゴルーチンでエラーが起きたときにメインゴルーチン側でなにかするとか考慮するとこれだと不十分。
そんなときはgolang.org/x/sync
で用意されている errgroup
パッケージを使うと良い。
errgroupパッケージの例
errgroupパッケージ
準標準ライブラリっていうのかな?
いつからか知らないけどerrgroupパッケージに同時実行数を制限するSetLimitが実装されているので、これを使う。
SetLimit limits the number of active goroutines in this group to at most n. A negative value indicates no limit.
Any subsequent call to the Go method will block until it can add an active goroutine without exceeding the configured limit.
The limit must not be modified while any goroutines in the group are active.
ちなみに、errgroupの実装をチラ見したところバッファ付きチャネルを使っており、上述したロジックに近いかたちで制御していた。
(SetLimitで指定した数のバッファ付きチャネルを作る)
package main
import (
"context"
"errors"
"fmt"
"time"
"golang.org/x/sync/errgroup"
)
const maxConcurrency = 3
func main() {
ctx := context.Background()
tasks := []string{"タスクA", "タスクB", "タスクC", "タスクD", "タスクE"}
eg, _ := errgroup.WithContext(ctx)
eg.SetLimit(maxConcurrency)
for _, t := range tasks {
tmp := t
eg.Go(func() error {
return doTask(tmp)
})
}
if err := eg.Wait(); err != nil {
fmt.Printf("err: %s\n", err)
}
fmt.Println("おわり")
}
func doTask(name string) error {
fmt.Printf("[%s] 実行中\n", name)
time.Sleep(300 * time.Millisecond) // 時間のかかる処理
if name == "タスクC" { // 格好悪いけど、無理やりエラーを起こす
return errors.New("エラーが起きた!!")
}
fmt.Printf("[%s] 完了\n", name)
return nil
}
結果
[タスクC] 実行中
[タスクA] 実行中
[タスクB] 実行中
[タスクD] 実行中
[タスクA] 完了
[タスクE] 実行中
[タスクB] 完了
[タスクE] 完了
[タスクD] 完了
err: エラーが起きた!!
おわり