TL;DR
- Go言語で行列の積を並列計算する実装例を示した(たぶんよりよい書き方あるので誰か教えて)
- ローカルのマシンとAWSのc4.8xlargeで実験すると、それなりに爆速になった
まえがき
競技プログラミングにおいて、行列累乗の計算が要求されることがよくあります。典型的な例は、複数の項の漸化式を行列の形に直すものでしょう。
行列の累乗の計算量は n
を行列のサイズ、P
を乗数とすると、繰り返し二乗法を用いることにより O(n^3logP)
です。これは、行列同士の掛け算の部分が O(n^3)
でボトルネックになっています。今回はこれを Goを用いて並列処理します。
Goの並列処理の仕組み
一言でまとめると、Goroutine
(ゴルーチンと読む)という存在が複数のスレッド上で走っていて、そいつらは channel
というものを介して相互に値をやり取りできます。
Goroutine
を生成するには go
構文を用いて以下のように書きます。
go func() {
// (何かしらの処理)
}()
定義してある関数を呼び出してもOKです。
func someFunction(somearg int) {
// (何かしらの処理)
}
go someFunction(114514)
これで // (何かしらの処理)
が別の Goroutine
で走ります。これらは処理をブロックせずに、Goroutine
を呼んだ側はそのまま走り続けます。
また、channel
を使うと以下のように値を送受信できます。
ch := make(chan int)
// 値を投げる
go func() {
ch <- 114514
}()
// 値をもらう
val := <- ch
// val = 114514
channel
から値をもらう方はデータが来るまで、値を投げる方は受け側でもらう準備ができるまで、処理をブロックします。(channelのバッファが0の場合) これを利用すると go
で呼び出した Goroutine
たちがすべて終わるまで待つことができます。
func someFunction(i int, ch chan int) {
// (何かしらの処理)
// 処理が終わったことを channel に物を詰めることで教える
ch <- 1
}
func main() {
ch := make(chan int)
// Goroutineを10個生成し、並列処理
for i := 0 ; i < 10 ; i ++ {
go someFunction(i, ch)
}
// Goroutineがすべて終わるのを待つ
for i := 0 ; i < 10 ; i ++ {
<- ch
}
}
行列の掛け算の実装
直列版を書く
まずは愚直に、2次元スライスで表された行列を掛け合わせる関数を書きましょう。
type Matrix [][]int
func mul(a, b *Matrix) Matrix {
ar := len(*a)
ac := len((*a)[0])
br := len(*b)
bc := len((*b)[0])
// 縦横のサイズが合わない場合
if ac != br {
panic("wrong matrix type")
}
// ここが O(n^3) になってる
c := make(Matrix, ar)
for i := 0 ; i < ar ; i++ {
c[i] = make([]int, bc)
for j := 0 ; j < bc ; j++ {
for k := 0 ; k < ac ; k++ {
c[i][j] += (*a)[i][k] * (*b)[k][j]
}
}
}
return c
}
処理の本体は後半部分です。3つ目のforループの中身が i行j列目の値を計算する部分になっています。
並列版にする
さて、これを並列版にしましょう。並列処理を書くときは、計算が並行で走っても結果が互いに影響しないことが大事です。
行列の掛け算の場合は各要素が独立に計算できるので、分け方は自明ですね。しかし各要素ごとに別の Goroutine
に計算させると、生成される Goroutine
の数は行列サイズの2乗となります。今回はサイズ1000の行列に対する計算を予定しているので、その数は100万です。こいつらを(Goに)管理させるのはかわいそうなので、行ごとに Goroutine
が走る仕様にします。
type Matrix [][]int
func computePart(i int, a, b, c *Matrix, ch chan int) {
ac := len((*a)[0])
bc := len((*b)[0])
for j := 0; j < bc; j++ {
part := 0
for k := 0; k < ac; k++ {
part += (*a)[i][k] * (*b)[k][j]
}
(*c)[i][j] = part
}
ch <- 1
}
func mulConcurrent(a, b *Matrix) Matrix {
ar := len(*a)
ac := len((*a)[0])
br := len(*b)
bc := len((*b)[0])
if ac != br {
panic("wrong matrix type")
}
c := make(Matrix, ar)
for i := 0 ; i < ar ; i++ {
c[i] = make([]int, bc)
}
ch := make(chan int)
// それぞれの行を並列処理させる
for i := 0 ; i < ar ; i++ {
go computePart(i, a, b, &c, ch)
}
// 終わるまで待つ
for i := 0 ; i < ar ; i++ {
<- ch
}
return c
}
前半部分は直列版と同じです。行を計算する部分をまるごと別の関数に切り出しました。
実験
実験に用いたコードの全文は ここから読めます。
プログラムは 2つのコマンドライン引数を受け取ります。第一引数は行列のサイズ、第二引数はGoroutineを同時に動かす数(並列度)です。並列度0の場合は、直列版の掛け算を使うようにします。
引数を受け取り、行列を生成して mul
または mulConcurrent
に投げるまでのコードを抜粋します。
func main() {
n, _ := strconv.Atoi(os.Args[1])
k, _ := strconv.Atoi(os.Args[2])
runtime.GOMAXPROCS(k)
a := make(Matrix, n)
b := make(Matrix, n)
for i := 0 ; i < n ; i++ {
a[i] = make([]int, n)
b[i] = make([]int, n)
for j := 0 ; j < n ; j++ {
a[i][j] = rand.Intn(100)
b[i][j] = rand.Intn(100)
}
}
var c Matrix
if k == 0 {
c = mul(&a, &b)
} else {
c = mulConcurrent(&a, &b)
}
fmt.Println(c[0][0])
}
実験結果(1)
手元のマシンで上述の matrix.go
をコンパイルし、並列度を変えて走らせました。プログラムが終わるまでにかかった時間を計測しました。環境は以下。
key | value |
---|---|
OS | OSX 10.10.5 |
CPU | 2 GHz Intel Core i7 |
メモリ | 8 GB 1600 MHz DDR3 |
Go | 1.5.1 |
--- | --- |
n | 1000(固定) |
k | 0,1,2,...,8 |
そして以下結果です。k=6までは処理時間が改善していますが、それ以降は逆に効率が落ちてますね。最大で3倍強の高速化を達成しました。
実験結果(2)
しかし、計算に7秒もかかっているようだとコンテストでは致命的です。今度はより多くのコアを擁するマシンで動かしてみます。EC2のc4.8xlargeを使ってみました。結果は以下の通りです。
爆速ですね!!Google Code Jamで提出の制限時間に間に合うか微妙な時、切り札的な使い方ができそうです。
まとめ
Go言語を使うとこのように労力をかけずに並列処理を実現できます。この知識を活かして来年のGoogle Code Jamで嘘解法を無理やり通し、Googleに喧嘩を売りつつ忠誠心を示しましょう。
今回は行列の掛け算を題材に並列処理の実装例を示してみました。コード中に(ここはこのような書き方のほうが良いよ)等ありましたらコメントor編集リクエストで教えていただけるとありがたいです!