はじめに
何かの探索でstd::map
をキャッシュとして使いたい。さらにその探索を複数プロセスでやったときに、そのキャッシュを同期したい。具体的には、異なるプロセス間で、キャッシュをマージしたい。そのためには、
- 各プロセスで
std::map
をシリアライズする - シリアライズしたデータをMPIで送受信する
- 受け取ったデータを自分の
std::map
にマージする - 以上をバタフライ型の通信で全ノード分集める
ということをすれば良い。すでにシリアライズ可能なstd::mapは書いたので、あとは通信部分を書くだけ。
ソースは
においておく。
方針
通信する相手のランク番号がdest
だとして、そいつとデータを送受信するとき、こういう形であって欲しい。
serializable_map<std::string, int> m;
// (snip)
std::vector<char> mybuffer = m.serialize();
std::vector<char> recvbuffer;
SendrecvVector(dest, mybuffer, recvbuffer);
m.deserialize(recvbuffer);
つまり、
- 自分のデータをシリアライズして
mybuffer
に入れる - 相手のシリアライズされたデータの受け皿
recvbuffer
を用意する - それぞれを
MPI_Sendrecv
する - 受け取ったデータをマージする
というのが1ステップ。これをバタフライ通信で全プロセス分の情報が集まるまで繰り返せば良い。すでにstd::map
のシリアライズ/デシリアライズは実装してあり、デシリアライズがマージそのものなので、後は通信部分だけ書けばよいことになる。
バタフライ通信
バタフライ通信というのは、ペアごとに情報をやりとりして同期していく方法で、例えば8プロセスのときには
step:0 (0,1) (2,3) (4,5) (6,7)
step:1 (0,2) (1,3) (4,6) (5,7)
step:2 (0,4) (1,5) (2,6) (3,7)
という3ステップの通信をおこなう。この時、
- 最初に0番は1番から情報をもらう。保持しているデータは [0,1]
- 次に0番は2番から情報をもらう。2番は前のステップで3番の情報をもらっていたので、0番が持つことになるデータは[0,1,2,3]
- 最後に4番から情報をもらうと、この時点で4番は[4,5,6,7]の情報を持っているので、最終的に0番は[0,1,2,3,4,5,6,7]とすべての情報を得る
このようにペアリングをすることで、一回ごとに情報が倍々に増えていく。
16プロセスのときには以下の4ステップになる。
step:0 (0,1) (2,3) (4,5) (6,7) (8,9) (10,11) (12,13) (14,15)
step:1 (0,2) (1,3) (4,6) (5,7) (8,10) (9,11) (12,14) (13,15)
step:2 (0,4) (1,5) (2,6) (3,7) (8,12) (9,13) (10,14) (11,15)
step:3 (0,8) (1,9) (2,10) (3,11) (4,12) (5,13) (6,14) (7,15)
で、このバタフライ通信、何度も書いてるわりに、毎回「通信相手の番号どうやって求めるんだっけ?」と頭をひねってる気がするので真面目に書いておく1。
まず、一回ごとに通信する集団のグループが倍になっていく。つまり、最初は2グループ、次は4グループ、次は8グループ、といった具合。そして、そのグループを2つにわけ、番号が若いほうとそうでない方でペアを組ませれば良い。通信すべき相手との番号の差も、ステップごとに倍々になっていく。
8プロセスの場合、最初のステップは
[0,1] [2,3] [4,5] [6,7]
とわかれて、これがそのままペアになる
次はまず
[0,1,2,3] [4,5,6,7]
と2つのグループにわかれる。最初のグループ[0,1,2,3]
を2つにわけると[0,1]
と[2,3]
にわかれ、それぞれから「差が2になるペア」を作ると(0,2), (1,3)
のペアを得る。残りも同様。
最後は
[0,1,2,3,4,5,6,7]
と一つにグループになるから、それを前後でわけると[0,1,2,3]
と[4,5,6,7]
。それぞれから「差が4になるペア」を作ると(0,4), (1,5), (2,6), (3,7)
となる。これで完成。
これを一般的に書く。自分のプロセス番号がrank
だとして、ステップを0から数えることにして、s
ステップ目の場合、$2^{(s+1)}$の大きさのグループにわけるのだから、まずrank % (1<<(s+1))
する。そのグループの中で番号が半分以下であるかどうかは、その数と(1 << s)
を比べれば良い。自分がグループの若い方である場合は、相手は自分の番号に(1 << s)
だけ足したものであり、そうでなければ引いたものである。
以上をまとめると、rank
番号がs
ステップ目に通信すべき相手の番号dest
は
int dest = (rank % (1 << (s + 1)) < (1 << s)) ? rank + (1 << s) : rank - (1 << s);
となる2。
先の通信部分と合わせるとこんな感じになる。procs
はプロセス数。
for (int s = 0; (1 << s) < procs; s++) {
int dest = (rank % (1 << (s + 1)) < (1 << s)) ? rank + (1 << s) : rank - (1 << s);
std::vector<char> mybuffer = m.serialize();
std::vector<char> recvbuffer;
SendrecvVector(dest, mybuffer, recvbuffer);
m.deserialize(recvbuffer);
MPI_Barrier(MPI_COMM_WORLD);
}
バリアはなくても大丈夫だと思うけど念のため。
シリアライズデータの送受信
シリアライズしたデータをstd::vector<char>
に格納しているので、それを送受信する必要がある。これは単純に
- まず送受信するサイズを送る
- 次にデータの実体を送る
と二段階の通信をすれば良い。そのまま実装するとこんな感じ。
template<class T>
void
SendrecvVector(int dest_rank, std::vector<T> &send_buffer, std::vector<T> &recv_buffer) {
int recv_size = 0;
int send_size = send_buffer.size();
MPI_Status st;
MPI_Sendrecv(&send_size, 1, MPI_INT, dest_rank, 0, &recv_size, 1, MPI_INT, dest_rank, 0, MPI_COMM_WORLD, &st);
recv_buffer.resize(recv_size);
MPI_Sendrecv(send_buffer.data(), send_size * sizeof(T), MPI_BYTE, dest_rank, 0, recv_buffer.data(), recv_size * sizeof(T), MPI_BYTE, dest_rank, 0, MPI_COMM_WORLD, &st);
}
テンプレートになっているのは、デバッグのときにstd::vector<int>
を送ったりしてたから。以上ですべて実装完了。
テスト
テストのために、プロセスごとに異なるハッシュを持っておき、それらが全部マージされたかどうかを調べる。まず、初期値はこんな感じにしておく。
serializable_map<std::string, int> m;
std::mt19937 mt(rank + 1);
std::uniform_int_distribution<int> ud(0, 10);
for (int i = 0; i < 3; i++) {
std::stringstream ss;
int value = ud(mt);
ss << "key" << std::setw(2) << std::setfill('0') << value;
m[ss.str()] = value;
}
4プロセスで実行した結果はこんな感じ。
$ mpic++ -std=c++11 main.cpp -o a.out
$ mpirun -np 4 ./a.out
before
rank = 0
key05:5
key08:8
key09:9
rank = 1
key06:6
key08:8
rank = 2
key08:8
key09:9
key10:10
rank = 3
key05:5
key07:7
key10:10
after
rank = 0
key05:5
key06:6
key07:7
key08:8
key09:9
key10:10
rank = 1
key05:5
key06:6
key07:7
key08:8
key09:9
key10:10
rank = 2
key05:5
key06:6
key07:7
key08:8
key09:9
key10:10
rank = 3
key05:5
key06:6
key07:7
key08:8
key09:9
key10:10
before
で4つのプロセスがバラバラのハッシュを持っていたのが、通信後はすべてのハッシュをマージした結果を共有するようになった。
まとめ
バタフライ演算を使って、複数のプロセスでハッシュのマージをしてみた。もっと苦労するかと思ったら、意外にあっさり組めた。シリアライズ可能なstd::map
さえあれば、20行足らずでできてしまう。
これでなんかの木の探索を複数プロセスでやる場合、プロセス間でキャッシュの共有ができるはず。もちろんずっとやってるとキャッシュがでかくなりすぎるから、実戦投入するには「ある程度太ったら古い情報を消す」みたいな工夫はいると思う。