追記 2022 03 22
CUDAについてくるthrustやCUBなどのライブラリでは、ここで説明するのとは異なるアルゴリズムでRadix Sortを実装しています。メジャーなアルゴリズムの中では最速だという理解でお願いします。
最近cudaでradix sortを組んだのでその仕組みと工夫ポイントの紹介をします。
長くなってしまうので何回かに分けて説明します。まずはRadix sortの仕組みから。
え? Quick sortじゃないの?
僕はそれほどソートアルゴリズムの歴史について詳しくありませんが、一般に最速と言われるアルゴリズムはQuick sortだと思います。CPUでの計算ならそれが最良の選択ですが、GPUの鬼の並列能力にかかる時、2つの数値の比較を逐一実行するQuick sortは効率が悪いのです。
Radix sortとは
この記事を見る方はRadix sortにしか興味がないと思うので、「Radix sortはBucket sortの進化系で〜」みたいな説明は省きます。簡単に特徴を上げておくと、
- 並列化しやすい。
- 比較がいらない。ifが減るのでgpu的には嬉しい。
- 数列の最大値・最小値が決まっていると高速化できる。
ソートの仕組み
Radix sortは何進数で説明するかが結構大事なのですが、どうせコンピュータにやらせるので2進数で説明します。適当に以下の数列を用意して、ストーリー立てて説明します。
001 | 010 | 011 | 101 | 011 | 110 | 111 | 000 | 001 | 100 | 101 | 011 |
---|
この数列は3ビット内に数字が収まっているので、並べ替えを3回行います。まずは最下位のビットだけを見て、前半にビットが0のものを、後半に1のものを持っていきます。
元の数列 | 001 | 010 | 011 | 101 | 011 | 110 | 111 | 000 | 001 | 100 | 101 | 011 |
---|---|---|---|---|---|---|---|---|---|---|---|---|
1ステップ目 | 010 | 110 | 000 | 100 | 001 | 011 | 101 | 011 | 111 | 001 | 101 | 011 |
並べ替えに使うビットを1桁ずつずらしながら同じ作業を繰り返します。次は2ビット目が0のものを前半に移動します。
元の数列 | 001 | 010 | 011 | 101 | 011 | 110 | 111 | 000 | 001 | 100 | 101 | 011 |
---|---|---|---|---|---|---|---|---|---|---|---|---|
1ステップ目 | 010 | 110 | 000 | 100 | 001 | 011 | 101 | 011 | 111 | 001 | 101 | 011 |
2ステップ目 | 000 | 100 | 001 | 101 | 001 | 101 | 010 | 110 | 011 | 011 | 111 | 011 |
そして3ビット目まで行うと完全にソートできます。
元の数列 | 001 | 010 | 011 | 101 | 011 | 110 | 111 | 000 | 001 | 100 | 101 | 011 |
---|---|---|---|---|---|---|---|---|---|---|---|---|
1ステップ目 | 010 | 110 | 000 | 100 | 001 | 011 | 101 | 011 | 111 | 001 | 101 | 011 |
2ステップ目 | 000 | 100 | 001 | 101 | 001 | 101 | 010 | 110 | 011 | 011 | 111 | 011 |
3ステップ目 | 000 | 001 | 001 | 010 | 011 | 011 | 011 | 100 | 101 | 101 | 110 | 111 |
コードを組むときは、どこかのビットを基準に並べ替える関数fと、必要な最小桁から最大桁までfをループで呼び出す関数gを用意するといいでしょう。別に最大・最小値が何かを調べる関数があってもいいかも知れませんね。
数列の最大・最小値が分かっているとループ数が削減できて速いことは理解できると思いますが、どこが並列かに向いているかわからないし、比較使ってるじゃんと思うでしょう。次はこの辺を説明します。
並べ替えの方法
さっきのストーリーだと、どう考えても0と1の比較をしているように見えますが、実に賢い方法で並べ替えます。並べ替える際には、移動先のインデックスが必要です。2進数の場合限定ですが比較なしで、かつ余計なメモリを消費せずに移動先のインデックスが求まる方法があります。
まず、対象のビットだけを取り出します。説明の概念として取り出しただけであって、このビットを格納するメモリは必要ないです。便宜上bと名付けます。
001 | 010 | 011 | 101 | 011 | 010 | 111 | 000 | 001 | 100 | 101 | 110 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|
b | 1 | 0 | 1 | 1 | 1 | 0 | 1 | 0 | 1 | 0 | 1 | 0 |
次にこれらを反転させたビットを専用のバッファに保存します。ここはメモリ使います。これも便宜上eとします。
001 | 010 | 011 | 101 | 011 | 010 | 111 | 000 | 001 | 100 | 101 | 110 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|
b | 1 | 0 | 1 | 1 | 1 | 0 | 1 | 0 | 1 | 0 | 1 | 0 |
e | 0 | 1 | 0 | 0 | 0 | 1 | 0 | 1 | 0 | 1 | 0 | 1 |
ここが肝の処理になるのですが、反転したビットの累積和を取ります。ここで言う累積和は以下の定義に従います。
$$S_n = a_0 + a_1 + \cdots + a_{n-1}$$
読みやすさのため配列名をfとします。
001 | 010 | 011 | 101 | 011 | 010 | 111 | 000 | 001 | 100 | 101 | 110 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|
b | 1 | 0 | 1 | 1 | 1 | 0 | 1 | 0 | 1 | 0 | 1 | 0 |
e | 0 | 1 | 0 | 0 | 0 | 1 | 0 | 1 | 0 | 1 | 0 | 1 |
f | 0 | 0 | 1 | 1 | 1 | 1 | 2 | 2 | 3 | 3 | 4 | 4 |
説明上は累積和fと反転ビットeを分けて書いていますが、組み込む上ではビット取り出しbも反転eも元の数列から計算できるので、反転ビットの配列に上書きする形で累積和を取れば良いです。
ここで、最終要素の反転ビットと累積和を足したものをtotalFalsesとします。今の場合totalFalses = 4 + 1 = 5です。次にびっくり操作をします。おもむろに元のインデックスiを使ってt = i - f + totalFalsesを計算します。
001 | 010 | 011 | 101 | 011 | 010 | 111 | 000 | 001 | 100 | 101 | 110 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|
b | 1 | 0 | 1 | 1 | 1 | 0 | 1 | 0 | 1 | 0 | 1 | 0 |
e | 0 | 1 | 0 | 0 | 0 | 1 | 0 | 1 | 0 | 1 | 0 | 1 |
f | 0 | 0 | 1 | 1 | 1 | 1 | 2 | 2 | 3 | 3 | 4 | 4 |
t = i - f + totalFalses | 0-0+5 =5 |
1-0+5 =6 |
2-1+5 =6 |
3-1+5 =7 |
4-1+5 =8 |
5-1+5 =9 |
6-2+5 =9 |
7-2+5 =10 |
8-3+5 =10 |
9-3+5 =11 |
10-4+5 =11 |
11-4+5 =12 |
これで移動先のインデックスdはd = (b == 1) ? t : fでもとまります。
001 | 010 | 011 | 101 | 011 | 010 | 111 | 000 | 001 | 100 | 101 | 110 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|
b | 1 | 0 | 1 | 1 | 1 | 0 | 1 | 0 | 1 | 0 | 1 | 0 |
e | 0 | 1 | 0 | 0 | 0 | 1 | 0 | 1 | 0 | 1 | 0 | 1 |
f | 0 | 0 | 1 | 1 | 1 | 1 | 2 | 2 | 3 | 3 | 4 | 4 |
t = i - f + totalFalses | 0-0+5 =5 |
1-0+5 =6 |
2-1+5 =6 |
3-1+5 =7 |
4-1+5 =8 |
5-1+5 =9 |
6-2+5 =9 |
7-2+5 =10 |
8-3+5 =10 |
9-3+5 =11 |
10-4+5 =11 |
11-4+5 =12 |
d | 5 | 0 | 6 | 7 | 8 | 1 | 9 | 2 | 10 | 3 | 11 | 4 |
dの通りに並べ替えると
001 | 010 | 011 | 101 | 011 | 010 | 111 | 000 | 001 | 100 | 101 | 110 |
---|---|---|---|---|---|---|---|---|---|---|---|
010 | 010 | 000 | 100 | 110 | 001 | 011 | 101 | 011 | 111 | 001 | 101 |
となります。感動。
各処理はgpuでどう記述するか。
GPUは大量のスレッドを持って並列処理ができますが、各スレッド間でのデータのやり取りはできません。通常はスレッドのインデックスと配列のインデックスを1対1対応させて、1要素あたり1スレッド割くようにします。この時、ビットの取り出し・ビット反転・配列の並べ替えは以下のように単純に記述できます。
//
// ビットの取り出し
//
// スレッドのインデックス。意味は次回やります。
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
// idx = (配列のインデックス)とするので、要素を取り出す時はこれだけでいい。
int elem = org_array[idx];
// ビット抽出も各スレッドが独立に動けばいい。
int fst_bit = (elem & 1);
//
// ビット反転
//
// ビット反転も普通にやればいい
int fst_e_bit = (int)(!(bool)fst_bit);
//
// 並べ替え
//
// 移動先のインデックスdが各スレッドで重複しなければこれで移動できる。
out_array[d] = ord_array[idx];
問題は累積和なのですが、長いので次回!