LoginSignup
47
40

More than 5 years have passed since last update.

Goでお手軽に行列の積を爆速並列計算

Last updated at Posted at 2015-11-07

TL;DR

  • Go言語で行列の積を並列計算する実装例を示した(たぶんよりよい書き方あるので誰か教えて)
  • ローカルのマシンとAWSのc4.8xlargeで実験すると、それなりに爆速になった

まえがき

競技プログラミングにおいて、行列累乗の計算が要求されることがよくあります。典型的な例は、複数の項の漸化式を行列の形に直すものでしょう。

行列の累乗の計算量は n を行列のサイズ、P を乗数とすると、繰り返し二乗法を用いることにより O(n^3logP) です。これは、行列同士の掛け算の部分が O(n^3) でボトルネックになっています。今回はこれを Goを用いて並列処理します。

Goの並列処理の仕組み

一言でまとめると、Goroutine(ゴルーチンと読む)という存在が複数のスレッド上で走っていて、そいつらは channelというものを介して相互に値をやり取りできます。

Goroutineを生成するには go構文を用いて以下のように書きます。

goroutine.go
go func() {
  // (何かしらの処理)
}()

定義してある関数を呼び出してもOKです。

goroutine.go
func someFunction(somearg int) {
  // (何かしらの処理)
}

go someFunction(114514)

これで // (何かしらの処理) が別の Goroutineで走ります。これらは処理をブロックせずに、Goroutineを呼んだ側はそのまま走り続けます。

また、channel を使うと以下のように値を送受信できます。

channel.go
ch := make(chan int)

// 値を投げる
go func() {
    ch <- 114514
}()

// 値をもらう
val := <- ch
// val = 114514

channel から値をもらう方はデータが来るまで、値を投げる方は受け側でもらう準備ができるまで、処理をブロックします。(channelのバッファが0の場合) これを利用すると go で呼び出した Goroutine たちがすべて終わるまで待つことができます。

blocking.go
func someFunction(i int, ch chan int) {
    // (何かしらの処理)

    // 処理が終わったことを channel に物を詰めることで教える
    ch <- 1
}

func main() {
    ch := make(chan int)

    // Goroutineを10個生成し、並列処理
    for i := 0 ; i < 10 ; i ++ {
        go someFunction(i, ch)
    }

    // Goroutineがすべて終わるのを待つ
    for i := 0 ; i < 10 ; i ++ {
        <- ch
    }
}

行列の掛け算の実装

直列版を書く

まずは愚直に、2次元スライスで表された行列を掛け合わせる関数を書きましょう。

matrix.go
type Matrix [][]int

func mul(a, b *Matrix) Matrix {
    ar := len(*a)
    ac := len((*a)[0])
    br := len(*b)
    bc := len((*b)[0])

    // 縦横のサイズが合わない場合
    if ac != br {
        panic("wrong matrix type")
    }

    // ここが O(n^3) になってる
    c := make(Matrix, ar)
    for i := 0 ; i < ar ; i++ {
        c[i] = make([]int, bc)
        for j := 0 ; j < bc ; j++ {
            for k := 0 ; k < ac ; k++ {
                c[i][j] += (*a)[i][k] * (*b)[k][j]
            }
        }
    }
    return c
}

処理の本体は後半部分です。3つ目のforループの中身が i行j列目の値を計算する部分になっています。

並列版にする

さて、これを並列版にしましょう。並列処理を書くときは、計算が並行で走っても結果が互いに影響しないことが大事です。

行列の掛け算の場合は各要素が独立に計算できるので、分け方は自明ですね。しかし各要素ごとに別の Goroutine に計算させると、生成される Goroutine の数は行列サイズの2乗となります。今回はサイズ1000の行列に対する計算を予定しているので、その数は100万です。こいつらを(Goに)管理させるのはかわいそうなので、行ごとに Goroutine が走る仕様にします。

matrix.go
type Matrix [][]int

func computePart(i int, a, b, c *Matrix, ch chan int) {
    ac := len((*a)[0])
    bc := len((*b)[0])
    for j := 0; j < bc; j++ {
        part := 0
        for k := 0; k < ac; k++ {
            part += (*a)[i][k] * (*b)[k][j]
        }
        (*c)[i][j] = part
    }
    ch <- 1
}

func mulConcurrent(a, b *Matrix) Matrix {
    ar := len(*a)
    ac := len((*a)[0])
    br := len(*b)
    bc := len((*b)[0])
    if ac != br {
        panic("wrong matrix type")
    }
    c := make(Matrix, ar)
    for i := 0 ; i < ar ; i++ {
        c[i] = make([]int, bc)
    }

    ch := make(chan int)

    // それぞれの行を並列処理させる
    for i := 0 ; i < ar ; i++ {
        go computePart(i, a, b, &c, ch)
    }

    // 終わるまで待つ
    for i := 0 ; i < ar ; i++ {
        <- ch
    }
    return c
}

前半部分は直列版と同じです。行を計算する部分をまるごと別の関数に切り出しました。

実験

実験に用いたコードの全文は ここから読めます。

プログラムは 2つのコマンドライン引数を受け取ります。第一引数は行列のサイズ、第二引数はGoroutineを同時に動かす数(並列度)です。並列度0の場合は、直列版の掛け算を使うようにします。

引数を受け取り、行列を生成して mul または mulConcurrent に投げるまでのコードを抜粋します。

matrix.go
func main() {
    n, _ := strconv.Atoi(os.Args[1])
    k, _ := strconv.Atoi(os.Args[2])
    runtime.GOMAXPROCS(k)

    a := make(Matrix, n)
    b := make(Matrix, n)
    for i := 0 ; i < n ; i++ {
        a[i] = make([]int, n)
        b[i] = make([]int, n)
        for j := 0 ; j < n ; j++ {
            a[i][j] = rand.Intn(100)
            b[i][j] = rand.Intn(100)
        }
    }
    var c Matrix
    if k == 0 {
        c = mul(&a, &b)
    } else {
        c = mulConcurrent(&a, &b)
    }
    fmt.Println(c[0][0])
}

実験結果(1)

手元のマシンで上述の matrix.go をコンパイルし、並列度を変えて走らせました。プログラムが終わるまでにかかった時間を計測しました。環境は以下。

key value
OS OSX 10.10.5
CPU 2 GHz Intel Core i7
メモリ 8 GB 1600 MHz DDR3
Go 1.5.1
--- ---
n 1000(固定)
k 0,1,2,...,8

そして以下結果です。k=6までは処理時間が改善していますが、それ以降は逆に効率が落ちてますね。最大で3倍強の高速化を達成しました。

matrix0.png

実験結果(2)

しかし、計算に7秒もかかっているようだとコンテストでは致命的です。今度はより多くのコアを擁するマシンで動かしてみます。EC2のc4.8xlargeを使ってみました。結果は以下の通りです。

matrix1.png

爆速ですね!!Google Code Jamで提出の制限時間に間に合うか微妙な時、切り札的な使い方ができそうです。

まとめ

Go言語を使うとこのように労力をかけずに並列処理を実現できます。この知識を活かして来年のGoogle Code Jamで嘘解法を無理やり通し、Googleに喧嘩を売りつつ忠誠心を示しましょう。
今回は行列の掛け算を題材に並列処理の実装例を示してみました。コード中に(ここはこのような書き方のほうが良いよ)等ありましたらコメントor編集リクエストで教えていただけるとありがたいです!

47
40
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
47
40