Go言語でバイトニックソート実装してみた

  • 18
    いいね
  • 1
    コメント
この記事は最終更新日から1年以上が経過しています。

こんばんは。
Go言語 Advent Calendar 7日目ばっちり遅刻しました。

普段はGo言語ではなくJavaScriptでWebGLを書いています。
何故突然Go言語Advent Calendarに登録したかというと自分を追い込みたかったからです。
なかなかGo言語を覚えるきっかけが見つからなかったのでこれがチャンス!と勢いで登録しました。
そして追い込みすぎた結果遅刻しました。

ほぼ初めてのGo言語です、お手柔らかにお願い致します。

何故バイトニックソートか

バイトニックソートは最良計算量と最悪計算量がともに$O(n\log(n)^2)$のソートアルゴリズムです。
クイックソートが平均計算量$O(n\log(n))$であることをを考えるとなんだか地味に見えるかもしれません。

ですが、バイトニックソートは並列計算ができるという驚異的な特徴を持っており、完全に並列化すると$O(\log(n)^2)$という速度を叩き出します。
これ、goroutineで書くとちょっと速いソートができるんじゃないかと思い書いてみました。

直列のバイトニックソート

実装はWikipediaのJavaのサンプルコードを参考にします。

package main

import (
    "fmt"
    "math/rand"
    "time"
)

func swap(data []float64, i, j int) {
    tmp := data[i]
    data[i] = data[j]
    data[j] = tmp
}

func f1(data []float64, p, q int) {
    d := 1 << uint(p-q)

    for i := range data {
        f2(data, i, p, d)
    }
}

func f2(data []float64, i, p, d int) {
    up := (i>>uint(p))&2 == 0
    j := i | d
    if i&d == 0 && (data[i] > data[j]) == up {
        swap(data, i, j)
    }
}

func sort(logn int, data []float64) []float64 {
    for i := 0; i < logn; i++ {
        for j := 0; j <= i; j++ {
            f1(data, i, j)
        }
    }
    return data
}

func test(data []float64) bool {
    end := len(data) - 1
    for i := 0; i < end; i++ {
        if data[i] > data[i+1] {
            return false
        }
    }
    return true
}

func createRandom(N int) []float64 {
    data := make([]float64, N)

    for i := range data {
        data[i] = float64(i) / float64(N)
    }

    for i := range data {
        swap(data, i, rand.Intn(N))
    }

    return data
}

func main() {
    logn := 24
    n := 1 << uint(logn)

    rand.Seed(time.Now().UnixNano())

    fmt.Printf("n: %d\n", n)
    fmt.Printf("logn: %d\n", logn)

    data := createRandom(n)

    t0 := time.Now().UnixNano()
    sort(logn, data)
    t1 := time.Now().UnixNano()

    fmt.Printf("t: %dms\n", (t1-t0)/1000000)

    if test(data) {
        fmt.Println("test => OK")
    } else {
        fmt.Println("test => NG")
    }
}

あぁ、ビット演算がところどころ出てきてちょっとややこしいですね。
普段JavaScriptばっかり書いてるので暗黙の型変換ができないのはなんだかもどかしいです。

これを実行すると次の結果が得られます。

n: 16777216
logn: 24
t: 20227ms
test => OK

ソートには成功しているようで、ソートにかかった時間はおおよそ20秒です。

goroutineを使ったバイトニックソート

関数f1に着目して、要素数だけループしているものをCPU数だけgoroutineに分割します。

package main

import (
    "fmt"
    "math"
    "math/rand"
    "runtime"
    "sync"
    "time"
)

func swap(data []float64, i, j int) {
    tmp := data[i]
    data[i] = data[j]
    data[j] = tmp
}

func f1(data []float64, p, q, block int) {
    d := 1 << uint(p-q)

    length := len(data)

    wg := &sync.WaitGroup{}

    begin := 0
    for begin < length {
        end := begin + block
        if end > length {
            end = length
        }

        wg.Add(1)
        go func(begin, end int) {
            for i := begin; i < end; i++ {
                f2(data, i, p, d)
            }
            wg.Done()
        }(begin, end)

        begin = end
    }

    wg.Wait()
}

func f2(data []float64, i, p, d int) {
    up := (i>>uint(p))&2 == 0
    j := i | d
    if i&d == 0 && (data[i] > data[j]) == up {
        swap(data, i, j)
    }
}

func sort(logn, block int, data []float64) []float64 {
    for i := 0; i < logn; i++ {
        for j := 0; j <= i; j++ {
            f1(data, i, j, block)
        }
    }
    return data
}

func createRandom(N int) []float64 {
    data := make([]float64, N)

    for i := range data {
        data[i] = float64(i) / float64(N)
    }

    for i := range data {
        swap(data, i, rand.Intn(N))
    }

    return data
}

func test(data []float64) bool {
    end := len(data) - 1
    for i := 0; i < end; i++ {
        if data[i] > data[i+1] {
            return false
        }
    }
    return true
}

func main() {
    logn := 24
    n := 1 << uint(logn)

    rand.Seed(time.Now().UnixNano())

    fmt.Printf("n: %d\n", n)
    fmt.Printf("logn: %d\n", logn)

    data := createRandom(n)
    cpus := runtime.NumCPU()

    //計算するブロック
    block := int(math.Ceil(float64(len(data)) / float64(cpus)))

    fmt.Printf("cpus: %d\n", cpus)
    fmt.Printf("block: %d\n", block)

    //つかうCPUを増やす
    runtime.GOMAXPROCS(runtime.NumCPU())

    t0 := time.Now().UnixNano()
    sort(logn, block, data)
    t1 := time.Now().UnixNano()

    fmt.Printf("t: %dms\n", (t1-t0)/1000000)

    if test(data) {
        fmt.Println("test => OK")
    } else {
        fmt.Println("test => NG")
    }
}

実行結果は以下のとおり、期待通り高速化できているようです。

n: 16777216
logn: 24
cpus: 8
block: 2097152
t: 6882ms
test => OK

おまけ

sort.Float64sでソートしてみると以下の結果になりました。

n: 16777216
logn: 24
t: 6856ms
test => OK

ううん、自前実装のほうが優位とは言えないですね、残念です。

Go言語を少し書いてみて

goroutineってなかなかおもしろいですね(小並
Visual Studio Codeでデバッグできたりもしますし、なかなか開発環境も優れていると思います。

WASMとか作れるようになると夢広がるなぁ

この投稿は Go その3 Advent Calendar 20157日目の記事です。