やってみたいこと
Bool値を100万個保持したいなどといったとき、ビット配列のデータ構造にすれば例えば128 KiBに1,048,576個の1/0値を保存することができます。ただ普通は計算機が整数番地でアクセスできる最小単位はバイトなので、ビット単位での読み書きには複数のビット命令が必要になります。
呼称や言語/ライブラリでのサポート状況は以下から:
DP(動的計画法)等では、ビット配列のi番目が立っていれば整数配列のi番目に1を足すといった(expand操作とでも呼びましょうか)が出てくるかもしれません。Mask付きのSIMD命令で上手く書けると気持ちいですね。ビット配列の側をSIMDのmaskとして利用してしまおうという魂胆です。
AVX-512の場合
Maskレジスタ(SVEではpredicateと呼びます)のロード/ストアですが、AVX-512ならkmovb/w/d/qの命令で普通の8/16/32/64-bit整数とほぼ同じ感覚でできます。SIMDのレーン数が64(=512/8)が最大というのはいいのですが、レジスタ幅と型の組み合わせによっては2-laneや4-laneになることもあるところ、マスクの型としては8-bitが最小のようです。256-bitのymmレジスタに32-bit整数をパックするときでも8-laneは使い切れるので、そこより小さい場合は気にする程のことでもないでしょう。
SVEの場合
以下、「たまたまこうやったら動いた」といった記述が多数あらわれます。
C/C++のレベルでは動作の正しさを保証しようがない話ばかりな点を予めご了承ください。
ARMのSVE (Scalable Vector Extension)で同じようなことをやろうと思ったとき、命令セットや組み込み関数のx86からの違いを意識しなくてはなりません。
ベクトル長が可変
Scalableというだけあって、SVEではベクトルレジスタの幅が$128n$-bitの可変長となっています。特定のCPUを仮定するのなら$n=4$のみを仮定してコードを書いてもいいのですが、今回はQEMUで走らせながら開発したので$n=1,2,4$ぐらいでは動くように心がけてみました。
ビットパターンが疎
Predicateレジスタの幅は$16n$-bitとなります。8-bit型をパックした型(svint8_t
など)に対しては全部のフラグを使うのですが、16-bit型をパックした型(svint16_t
やsvfloat16_t
)に対しては下半分ではなく偶数番目のビットを使います。SIMDレーン4つ分のビットパターンでいうと、$a,b,c,d\in{0,1}$に対して8-bitなら0b$dcba$だったものが16-bitになると0b0$d$0$c$0$b$0$a$に、32-bitでは0b000$d$000$c$000$b$000$a$のようになります。
サイズの異なる型変換のときも似たような作りになっており、小さい方の型は偶数番目ののレーンだけが変換の対象となります。どうも設計思想として「水平方向のデータ移動を極力避ける」というものがあるみたいです。どれだけ効くのかは知りませんが電力効率が理由のようです。
Predicateのload/store組み込み関数が見当たらない
命令セットにはpredicateレジスタ用にもLDR/STRという命令が確かにあるのですが、ACLE (ARM C Language Extension) for SVEのマニュアルには対応命令が見当たりませんでした。仕方がないので今回はtype-punningという書き方をしたらコンパイラが命令を出してくれたのでそのようにしました。uint16_t *ptr
というポインタがあったとき、svbool_t p = *(svbool_t *)ptr
のように書きます。
コード全体
擬似コードとしてはbitがもし[i]でアクセスできるなら
for(int i=0; i<64; i++) val[i] += bit[i];
みたいなものです。
まずは全部:
#include <cstdio>
#include <cstdint>
#include <arm_sve.h>
int main(){
volatile uint16_t bits[4] = {
0x3210,
0x7654,
0xba98,
0xfedc,
};
int16_t vals[64] = {0, };
const int inc = svcntb() / 16; // 128-bitなら4回転、512-bitなら1回する
const int vlen = svcnth(); // 16-bit型が何レーンあるか
for(int i=0, off=0; i<4; i+=inc, off+=2*vlen){
svbool_t mask = *(svbool_t *)(bits + i); // このままではint8_t用のpredicate
svbool_t pzero = svpfalse(); // 全部ゼロのpredicate
svbool_t plo = svzip1_b8(mask, pzero); // 下半分を伸長
svbool_t phi = svzip2_b8(mask, pzero); // 上半分を伸長
svbool_t pone = svptrue_b16(); // 偶数番bitが全部1
svint16_t vlo = svld1_s16(pone, &vals[off]);
svint16_t vhi = svld1_s16(pone, &vals[off + vlen]);
vlo = svadd_n_s16_m(plo, vlo, 1);
vhi = svadd_n_s16_m(phi, vhi, 1);
svst1_s16(pone, &vals[off], vlo);
svst1_s16(pone, &vals[off + vlen], vhi);
}
for(int i=0; i<64; i+=4){
for(int j=0; j<4; j++){
printf("%d", vals[i+j]);
}
puts("");
}
return 0;
// デバッグ用
puts("inp");
for(int i=0; i<64; i+=4){
for(int j=0; j<4; j++){
int k = i+j;
int w = k/16;
int b = k%16;
int v = (bits[w] & (1<<b)) ? 1 : 0;
printf("%d", v);
}
puts("");
}
return 0;
}
実行結果
$ g++ -march=armv8-a+sve -Wall -O2 bitmap.cpp
$ qemu-aarch64 -cpu max,sve512=on ./a.out
0000
1000
0100
1100
0010
1010
0110
1110
0001
1001
0101
1101
0011
1011
0111
1111
これは左右反転の2進数表示で0〜15まで増えていくだけの数列です。
環境ですがM1 macにUTMでUbuntuを入れてaptでg++とqemu-userを入れただけのものです。
512のところを256や128にしても同じ結果が得られました(得られるまでデバッグしました)。
makefileはこちらをベースとして借用しています。
GCC11の出力はこうなりました:
.L2:
ld1h z1.h, p1/z, [x1]
ld1h z0.h, p1/z, [x0]
add x5, x7, x3, lsl 1
ldr p0, [x5]
zip1 p3.b, p0.b, p2.b
add z1.h, p3/m, z1.h, z2.h
st1h z1.h, p1, [x1]
zip2 p0.b, p0.b, p2.b
add w2, w2, w8
add z0.h, p0/m, z0.h, z2.h
st1h z0.h, p1, [x0]
add x3, x3, x6
add x1, x1, x4
add x0, x0, x4
cmp w2, 3
ble .L2
下手に人が書いたソースなんかより綺麗なんじゃないでしょうか?
解説
volatile uint16_t bits[4] = {
0x3210,
0x7654,
0xba98,
0xfedc,
};
int16_t vals[64] = {0, };
64-bit分のビットパターンを即値で埋めました。配列や文字列は下位から上位の順で書けるのに、即値は何で上位から下位なのでしょう?なお、volatileを付けておかないとtype-punningによる読み込みを「未使用」と判断するのかg++は初期化を飛ばしてしまいます。下の方に(普通のC言語で)デバッグコードが付属しているの、要するにこの配列の内容を表示するコードが存在すると上の方のSVEのコードも動作するようになったという経緯の名残りです。
結果用のvalsはshortを64語で0初期化しておきます。
const int inc = svcntb() / 16; // 128-bitなら4回転、512-bitなら1回する
const int vlen = svcnth(); // 16-bit型が何レーンあるか
for(int i=0, off=0; i<4; i+=inc, off+=2*vlen){
読み込むのは反復あたり16$n$-bit、書き出しがshorが16$n$個なので32$n$ byteです。比例関係なので本当は2つも呼ばなくてもいいのですが、中途たくさん間違えた名残りです。
svbool_t mask = *(svbool_t *)(bits + i); // このままではint8_t用のpredicate
svbool_t pzero = svpfalse(); // 全部ゼロのpredicate
svbool_t plo = svzip1_b8(mask, pzero); // 下半分を伸長
svbool_t phi = svzip2_b8(mask, pzero); // 上半分を伸長
ここがメインディッシュとなります。先ずtype-punningでビットパターンを読み込みます。
このままでは8-bit型用のpredicateです。これの下半分と上半分を取り出して引き伸ばしてやると16-bit型用のpredicateになります。ZIP1/ZIP2という命令でこれは実現できます。どんな命令かはこちらから:
ただこれ、普通のベクトル型に対してはACLEのマニュアルに記述があるのですが、predicateに対しては見当たりませんでした。clang++のヘッダにはあったのでこう書いてしまったらg++でも命令が出たのでこれでいいことにしました。
svbool_t pone = svptrue_b16(); // 偶数番bitが全部1
svint16_t vlo = svld1_s16(pone, &vals[off]);
svint16_t vhi = svld1_s16(pone, &vals[off + vlen]);
vlo = svadd_n_s16_m(plo, vlo, 1);
vhi = svadd_n_s16_m(phi, vhi, 1);
svst1_s16(pone, &vals[off], vlo);
svst1_s16(pone, &vals[off + vlen], vhi);
}
これは単にレジスタ2本にデータを読んで、マスク付き加算を行い結果を書き出しています。
svadd_n_s16_m
なのですが、ここの_n
は最後のオペランドがスカラーであること、_m
はpredicateでマスクされていたら元の値を保持するというものです(代わりに_z
だとゼロに、_x
だとどっちでもいいとなるみたいです)。
感想
いかがでしたか?
思ったように命令が出てくれないときはなかなかもどかしい思いをしますね。
慣れている方はいっそのことアセンブリ言語で書いてしまった方が早いかもしれません。