LoginSignup
59
60

More than 5 years have passed since last update.

x/net/context の実装パターン

Last updated at Posted at 2015-08-07

モチベーション

例えば、外部のリソースにアクセスする時など、並列に実行する部分の処理を共通化したいことが多いと思う。特に既存のコードは逐次実行されるコードになっているが、局所的に並列化したい時などは、チャネル使って書き直すのはなかなかしんどかったりする。

そこで、x/net/context を使いつつ、並列処理を共通化する方法を考えてみた。

x/net/context

Google製の x/net/context パッケージというのがある。
使い方などは、以下のリンクが詳しい

とてもシンプルだけどよく出来ているパッケージで、名前の通り、コンテキストなので、共通の値を管理する目的に主眼が置かれているけど、もう一つ、チャネルのキャンセル処理のための機能が備わっていて、キャンセル処理が階層化されて管理されたり(親をキャンセルしたら、それが子に伝搬する)、何度でもキャンセルを呼んでもエラーにならない、など、細かいがgorutineのキャンセル処理を実装するときに使いやすい構造になっている。

実装

基本的なアイデアは、タスクを、 type ContextErrFunc func(ctx context.Context) error という型に統一して、シリアルに実行する、並列に実行する、というのを組み合わせることができるようにする。

package flow

import "golang.org/x/net/context"

// ContextErrFunc : 処理共通化のためのタスクの定義用
type ContextErrFunc func(ctx context.Context) error

// ContextSerial : ContextErrFunc を直列に実行して、エラーが途中で起こったらその時点でエラーを返す
func ContextSerial(fs ...ContextErrFunc) ContextErrFunc {
    return func(ctx context.Context) error {
        for _, f := range fs {
            if f == nil {
                continue
            }
            if err := f(ctx); err != nil {
                return err
            }
        }
        return nil
    }
}

// ContextParallel : ContextErrFunc を並列に実行して、エラーが途中で起こったらその時点でエラーを返す
func ContextParallel(fs ...ContextErrFunc) ContextErrFunc {
    return func(ctx context.Context) error {
        childCtx, cancelAll := context.WithCancel(ctx)
        defer cancelAll()

        doneCh := make(chan struct{}, len(fs))
        errCh := make(chan error, len(fs))
        recoverCh := make(chan interface{}, len(fs))

        for _, f := range fs {
            go func(_f ContextErrFunc) {
                defer func() {
                    r := recover()
                    if r != nil {
                        recoverCh <- r
                    }
                }()

                if _f == nil {
                    doneCh <- struct{}{}
                    return
                }

                if err := _f(childCtx); err != nil {
                    errCh <- err
                    return
                }
                doneCh <- struct{}{}
            }(f)
        }

        for i := 0; i < len(fs); i++ {
            select {
            case <-ctx.Done():
                return ctx.Err()
            case <-doneCh:
            case err := <-errCh:
                return err
            case r := <-recoverCh:
                panic(r)
            }
        }
        return nil
    }
}

使い方

例: googleの解説記事の httpDo

ContextErrFunc 型のタスクを返す関数を定義する

func httpDoTask(req *http.Request, f func(*http.Response, error) error) flow.ContextErrFunc {
    return func(ctx context.Context) error {
        // Run the HTTP request in a goroutine and pass the response to f.
        tr := &http.Transport{}
        client := &http.Client{Transport: tr}
        c := make(chan error, 1)
        go func() { c <- f(client.Do(req)) }()
        select {
        case <-ctx.Done():
            tr.CancelRequest(req)
            <-c // Wait for f to return.
            return ctx.Err()
        case err := <-c:
            return err
        }
    }
}

func main() {
    req1, _ := http.NewRequest("GET", "http://google.com", nil)
    req2, _ := http.NewRequest("GET", "http://yahoo.com", nil)
    req3, _ := http.NewRequest("GET", "http://microsoft.com", nil)

    // 全体で5秒でタイムアウト
    tc, cancel := context.WithTimeout(context.Background(), 5*time.Second)
    defer cancel()

    var res1, res2, res3 *http.Response

    if err := flow.ContextSerial(
        // 2つ並列に実行してから
        flow.ContextParallel(
            httpDoTask(req1, func(ares1 *http.Response, err error) error {
                res1 = ares1
                return err
            }),
            httpDoTask(req2, func(ares2 *http.Response, err error) error {
                res2 = ares2
                return err
            }),
        ),
        // 3つめを実行
        httpDoTask(req3, func(ares3 *http.Response, err error) error {
            res3 = ares3
            return err
        }),
    )(tc); err != nil {
        log.Fatal(err)
    }

    fmt.Println(res1.StatusCode, res2.StatusCode, res3.StatusCode)
}

まとめと、何をcontextに入れるか

context という名前だけあって、例えばサーバーのハンドラの中で使うときには、そのリクエストのコンテキストにかかわる部分(ログインしているユーザーとか、DBの接続先とか) のみを入れるようにして、その他は呼ぶ関数の引数にしたほうが混乱が少ないコードになる。

例に書いたように、引数を取るような処理は、type ContextErrFunc func(ctx context.Context) error という型の1つのタスクを返す関数として定義することで、全体のフロー制御を共通化しつつ、キャンセル処理などをうまく扱うコードにできそう。

59
60
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
59
60