前回の記事
AIアクセラレータ・IMAXの紹介 ~ (4) SIMD化とループ効率化
行列積のIMAX実装
AI分野において、行列積は馴染みのある演算ではないでしょうか。これは、パーセプトロンからなるレイヤの計算を行列で表せるためです。GoogleのTPUやNVIDIAのTensor Coreも、このような行列積を効率よく演算できる仕組みになっています。
IMAXでは、単純な行列積ではTPU等にスループットでは勝てませんが、その分汎用性はあります。「プログラム可能」というのは、正にそういった意味です。しかし、AI分野においてIMAXを使うためには、IMAXで行列積が実装可能でなければなりません。今回は、その行列積の例題を紹介しながら、IMAXでどうやって行列積を実装しているか、語っていきます。仕様的に新しいものの紹介はありませんが、どのような実装をしているのか、見ていきましょう。
そもそも行列積とは
行列積は、行列の掛け算です。ただ、行列を要素ごとに掛けるのではなく(これはアダマール積と言います)、このような規則で掛けていきます。n x mサイズの行列$A$とm x lサイズの行列$ B $の行列積$ AB $は次のような式で表すことができます。
AB_{ik} = \sum_{j=1}^{m}{A_{ij}B_{jk}}
式で表すとこのようになりますが、これを普通のプログラミング言語で表すと三重ループになります。単純に考えたら時間計算量は$ O(n・m・l) $です。しかし、行列積の演算は、$AB_{ik}$が表しているように、他の要素に対しては並列演算が可能です。i及びkは以前の演算結果を要しないので、スレッドレベルで並列実行して問題ありません。
OpenMPやCUDAなどでの実装は、スレッドレベルでの並列化を実装している場合がほとんどです。もちろん、中のループで行っている演算は単純規則化されているため、同時並列実行可能なコアの多いGPUが速いわけです。しかし、実はもっと高速化できる方法があります。ハードウェアレベルでの改造は必要になりますがね、、、
何故AIアクセラレータは行列積を爆速で行えるのか
CPUもCPUもそうですが、スレッドレベルでの並列化においても、上式の演算は一旦掛け算を行った値をどこかに保持し、後から取り出して足していきます。このような作業を繰り返しているため、(大体ですが)全コアで並列実行可能な規模でもループ分の時間を要することになります。場合によっては、一旦値をメモリに書き戻すなどの操作を行わないとならないかもしれません。
これ以上に高速化・効率化する案として、演算器を空間上に拡張する方法があります。一つの複雑なプログラムを実行できる「コア」ではなく、空間上に演算器を広げ、工場のコンベアベルトのように周辺の演算器にデータを流していく方法です。これが現代のAIアクセラレータの基本的な考え方になります。
TPUは、正にこのような演算器の配置を行い、行列積演算において高い性能を達成しています。しかし、このように演算器を「配置」しただけでは行列積以外できないので、この演算器の命令を「書き換える」ことを考えました。これがCGRAなのです。IMAXも例外ではなく、このような構造を上手く活用し行列積をGPUより効率よく行えるCGRAの一つです。次では、IMAXでどのような行列積を実現しているのか、見ていきます。
行列積の例
マニュアルでは、Column-Majorな行列の表現方式を使用しています。Row-Majorに慣れている方は多少読みづらいかもしれません。
//EMAX5A begin mm mapdist=0
for (CHIP=0; CHIP<NCHIP; CHIP++) { /* will be parallelized by multi-chip (M/#chip) */
for (INIT1=1,LOOP1=RMGRP,rofs=(0-L*4)<<32|((0-M2*4)&0xffffffff); LOOP1--; INIT1=0) { /* stage#0 *//* mapped to FOR() on BR[63][1][0] */
for (INIT0=1,LOOP0=M2/W,cofs=(0-W*4)<<32|((0-W*4)&0xffffffff); LOOP0--; INIT0=0) { /* stage#0 *//* mapped to FOR() on BR[63][0][0] */
// Step1. アドレス計算
exe(OP_ADD, &cofs, INIT0?cofs:cofs, EXP_H3210, (W*4)<<32|(W*4), EXP_H3210, 0LL, EXP_H3210, OP_AND, 0xffffffffffffffffLL, OP_NOP, 0LL);
exe(OP_ADD, &rofs, rofs, EXP_H3210, INIT0?(L*4)<<32|(M2*4):0, EXP_H3210, 0LL, EXP_H3210, OP_NOP, 0LL, OP_NOP, 0LL); /* stage#0 */
exe(OP_ADD, &oofs, rofs, EXP_H3210, cofs, EXP_H3210, 0, EXP_H3210, OP_AND, 0xffffffff, OP_NOP, 0LL); /* stage#1 */
// Step2. データロード及び積和演算
mop(OP_LDWR, 1, &BR[1][0][1], (Ull)b000, (Ull)cofs, MSK_W1, (Ull)b00, M2, 0, 0, (Ull)NULL, M2); /* stage#1 */
mop(OP_LDWR, 1, &BR[1][0][0], (Ull)b001, (Ull)cofs, MSK_W1, (Ull)b00, M2, 0, 0, (Ull)NULL, M2); /* stage#1 */
mop(OP_LDWR, 1, &BR[1][1][1], (Ull)b002, (Ull)cofs, MSK_W1, (Ull)b00, M2, 0, 0, (Ull)NULL, M2); /* stage#1 */
mop(OP_LDWR, 1, &BR[1][1][0], (Ull)b003, (Ull)cofs, MSK_W1, (Ull)b00, M2, 0, 0, (Ull)NULL, M2); /* stage#1 2KB */
mop(OP_LDWR, 1, &BR[1][2][1], (Ull)a00[CHIP], (Ull)rofs, MSK_W1, (Ull)a0[CHIP], L*RMGRP, 0, 0, (Ull)NULL, L*RMGRP); /* stage#1 16KB */
exe(OP_FML, &AR[2][0], BR[1][0][1], EXP_H3210, BR[1][2][1], EXP_H3210, 0LL, EXP_H3210, OP_NOP, 0LL, OP_NOP, 0LL); /* stage#2 */
exe(OP_FML, &AR[2][1], BR[1][0][0], EXP_H3210, BR[1][2][1], EXP_H3210, 0LL, EXP_H3210, OP_NOP, 0LL, OP_NOP, 0LL); /* stage#2 */
exe(OP_FML, &AR[2][2], BR[1][1][1], EXP_H3210, BR[1][2][1], EXP_H3210, 0LL, EXP_H3210, OP_NOP, 0LL, OP_NOP, 0LL); /* stage#2 */
exe(OP_FML, &AR[2][3], BR[1][1][0], EXP_H3210, BR[1][2][1], EXP_H3210, 0LL, EXP_H3210, OP_NOP, 0LL, OP_NOP, 0LL); /* stage#2 */
mop(OP_LDWR, 1, &BR[2][0][1], (Ull)b010, (Ull)cofs, MSK_W1, (Ull)b01, M2, 0, 0, (Ull)NULL, M2); /* stage#2 */
mop(OP_LDWR, 1, &BR[2][0][0], (Ull)b011, (Ull)cofs, MSK_W1, (Ull)b01, M2, 0, 0, (Ull)NULL, M2); /* stage#2 */
mop(OP_LDWR, 1, &BR[2][1][1], (Ull)b012, (Ull)cofs, MSK_W1, (Ull)b01, M2, 0, 0, (Ull)NULL, M2); /* stage#2 */
mop(OP_LDWR, 1, &BR[2][1][0], (Ull)b013, (Ull)cofs, MSK_W1, (Ull)b01, M2, 0, 0, (Ull)NULL, M2); /* stage#2 2KB */
mop(OP_LDWR, 1, &BR[2][2][1], (Ull)a01[CHIP], (Ull)rofs, MSK_W1, (Ull)a0[CHIP], L*RMGRP, 0, 0, (Ull)NULL, L*RMGRP); /* stage#2 16KB */
:
exe(OP_FMA, &AR[60][0], AR[59][0], EXP_H3210, BR[59][2][1], EXP_H3210, BR[59][0][1], EXP_H3210, OP_NOP, 0LL, OP_NOP, 0LL); /* stage#60 */
exe(OP_FMA, &AR[60][1], AR[59][1], EXP_H3210, BR[59][2][1], EXP_H3210, BR[59][0][0], EXP_H3210, OP_NOP, 0LL, OP_NOP, 0LL); /* stage#60 */
exe(OP_FMA, &AR[60][2], AR[59][2], EXP_H3210, BR[59][2][1], EXP_H3210, BR[59][1][1], EXP_H3210, OP_NOP, 0LL, OP_NOP, 0LL); /* stage#60 */
exe(OP_FMA, &AR[60][3], AR[59][3], EXP_H3210, BR[59][2][1], EXP_H3210, BR[59][1][0], EXP_H3210, OP_NOP, 0LL, OP_NOP, 0LL); /* stage#60 */
mop(OP_LDWR, 1, &BR[60][0][1], (Ull)b590, (Ull)cofs, MSK_W1, (Ull)b59, M2, 0, 0, (Ull)NULL, M2); /* stage#60 */
mop(OP_LDWR, 1, &BR[60][0][0], (Ull)b591, (Ull)cofs, MSK_W1, (Ull)b59, M2, 0, 0, (Ull)NULL, M2); /* stage#60 */
mop(OP_LDWR, 1, &BR[60][1][1], (Ull)b592, (Ull)cofs, MSK_W1, (Ull)b59, M2, 0, 0, (Ull)NULL, M2); /* stage#60 */
mop(OP_LDWR, 1, &BR[60][1][0], (Ull)b593, (Ull)cofs, MSK_W1, (Ull)b59, M2, 0, 0, (Ull)NULL, M2); /* stage#60 */
mop(OP_LDWR, 1, &BR[60][2][1], (Ull)a59[CHIP], (Ull)rofs, MSK_W1, (Ull)a0[CHIP], L*RMGRP, 0, 0, (Ull)NULL, L*RMGRP); /* stage#60 */
exe(OP_FMA, &AR[61][0], AR[60][0], EXP_H3210, BR[60][2][1], EXP_H3210, BR[60][0][1], EXP_H3210, OP_NOP, 0LL, OP_NOP, 0LL); /* stage#61 */
exe(OP_FMA, &AR[61][1], AR[60][1], EXP_H3210, BR[60][2][1], EXP_H3210, BR[60][0][0], EXP_H3210, OP_NOP, 0LL, OP_NOP, 0LL); /* stage#61 */
exe(OP_FMA, &AR[61][2], AR[60][2], EXP_H3210, BR[60][2][1], EXP_H3210, BR[60][1][1], EXP_H3210, OP_NOP, 0LL, OP_NOP, 0LL); /* stage#61 */
exe(OP_FMA, &AR[61][3], AR[60][3], EXP_H3210, BR[60][2][1], EXP_H3210, BR[60][1][0], EXP_H3210, OP_NOP, 0LL, OP_NOP, 0LL); /* stage#61 */
mop(OP_LDWR, 1, &BR[62][0][1], (Ull)c00[CHIP], (Ull)oofs, MSK_W0, (Ull)c0[CHIP], M2*RMGRP, 0, 1, (Ull)NULL, M2*RMGRP); /* stage#62 */
// Step3. 結果格納
mop(OP_LDWR, 1, &BR[62][1][1], (Ull)c01[CHIP], (Ull)oofs, MSK_W0, (Ull)c0[CHIP], M2*RMGRP, 0, 1, (Ull)NULL, M2*RMGRP); /* stage#62 */
mop(OP_LDWR, 1, &BR[62][2][1], (Ull)c02[CHIP], (Ull)oofs, MSK_W0, (Ull)c0[CHIP], M2*RMGRP, 0, 1, (Ull)NULL, M2*RMGRP); /* stage#62 */
mop(OP_LDWR, 1, &BR[62][3][1], (Ull)c03[CHIP], (Ull)oofs, MSK_W0, (Ull)c0[CHIP], M2*RMGRP, 0, 1, (Ull)NULL, M2*RMGRP); /* stage#62 */
exe(OP_FAD, &AR[62][0], AR[61][0], EXP_H3210, BR[62][0][1], EXP_H3210, 0LL, EXP_H3210, OP_NOP, 0LL, OP_NOP, 0LL); /* stage#62 */
exe(OP_FAD, &AR[62][1], AR[61][1], EXP_H3210, BR[62][1][1], EXP_H3210, 0LL, EXP_H3210, OP_NOP, 0LL, OP_NOP, 0LL); /* stage#62 */
exe(OP_FAD, &AR[62][2], AR[61][2], EXP_H3210, BR[62][2][1], EXP_H3210, 0LL, EXP_H3210, OP_NOP, 0LL, OP_NOP, 0LL); /* stage#62 */
exe(OP_FAD, &AR[62][3], AR[61][3], EXP_H3210, BR[62][3][1], EXP_H3210, 0LL, EXP_H3210, OP_NOP, 0LL, OP_NOP, 0LL); /* stage#62 */
mop(OP_STWR, 1, &AR[62][0], (Ull)oofs, (Ull)c00[CHIP], MSK_D0, (Ull)c0[CHIP], M2*RMGRP, 0, 1, (Ull)NULL, M2*RMGRP); /* stage#62 */
mop(OP_STWR, 1, &AR[62][1], (Ull)oofs, (Ull)c01[CHIP], MSK_D0, (Ull)c0[CHIP], M2*RMGRP, 0, 1, (Ull)NULL, M2*RMGRP); /* stage#62 */
mop(OP_STWR, 1, &AR[62][2], (Ull)oofs, (Ull)c02[CHIP], MSK_D0, (Ull)c0[CHIP], M2*RMGRP, 0, 1, (Ull)NULL, M2*RMGRP); /* stage#62 */
mop(OP_STWR, 1, &AR[62][3], (Ull)oofs, (Ull)c03[CHIP], MSK_D0, (Ull)c0[CHIP], M2*RMGRP, 0, 1, (Ull)NULL, M2*RMGRP); /* stage#62 */
}
}
}
//EMAX5A end
演算フロー
Step1. アドレス計算
ロードするデータのインデックスを指定します。ループ文にてこの設定はできないので、最初の段で設定を行います。
Step2. データロード及び積和演算
演算対象の行列から先ほどのインデックスのデータをロードし、積和演算を繰り返します。
Step3. 値の格納
最終的な値を格納するところです。1サイクルでは行列積の演算は終わらないので、前回の値をロードし、格納する仕組みになっています。
解説
ループ以前の前処理
前処理として、行列のデータを細かく分割します。アドレスのコピーは要らず、アドレス単位で分割したら大丈夫です。その分割したアドレスを示すのがb000``b001
のようなものです。このアドレスの元、オフセットで行及び列を指定します。
ループ
一見ループだけ見ればこれでオフセットの指定ができるように見えますが、使用中のユニット数を把握しやすくするために実際のオフセットはループ中の処理でやっています。なお、INIT0?cofs:cofs
ですが、コンパイラ上で特殊文法として扱いながらエミュレーションの際にはC言語のコードとして動かすための工夫です。このように記述した場合、CGRA自ら初期化を行う設定になります。
演算
分割された指定のアドレスからオフセットを出し、ロードし、積和演算を行うが全てです。axx[CHIP]
は行列aの列の指定、bxxx
は行列bの行指定になっています。そこから、rofs
及びcofs
を用いてロードしたいデータを指定し、積和演算を行います。ちなみに、このオフセットはバイト単位となっているので、一要素のビット数を確認したから指定しましょう。
更なる効率化
もちろんこの行列積も前回の記事のように、SIMD化を行い2倍のスループットを出すことはできます。しかし、多少特殊なアドレスの並びになってしまうので、使用するアプリによってはあまり向いてないかもしれません。
おわりに
今回は、AI演算において最も使用されるとされる行列積演算のIMAX実装について解説しました。しかし、簡単な行列積だけだとTPUの方が優れているので、プログラミング可能という点を活かして他のアプリケーションをどういうふうに効率化しているのか、見ていきたいと思います。ではまた。
次回の記事
AIアクセラレータ・IMAXの紹介 ~ (6) 畳み込み演算