前回の記事
畳み込み演算について
畳み込み演算は、画像等の特徴量抽出によく用いられる演算です。畳み込み演算を用い、様々な画像フィルターを実現することができます。また、CNN(Convolutional Neural Netowrk)という畳み込み演算を用いたニューラルネットワークモデルは、今でも画像処理分野において広く使われています。この演算の実装可否は、AIアクセラレータとして使うにはとても重要な要素となり得ます。
具体的な演算
畳み込み演算は、連続した区間において次のような積分で定義されます。
(f * g)(t) = \int f(\tau) g(t - \tau)\, d\tau
$f$と$g$の畳み込み演算は、$f$に対し、$\tau$分ずれた$g$の値をかけて積分することを繰り返す演算となります。離散空間においては、次のように定義されます。
(f * g)(m) = \sum_n {f(n) \, g(m - n)}
もちろん、画像のような二次元空間でも似たような定義ができます。
(f * g)(m, n) = \sum_i \sum_j {f(i, j) \, g(m - i, n -j)}
このような演算アルゴリズムを踏まえて、IMAXでの実装を考えていきましょう。
畳み込み演算の並列処理
勘のいい方はすでに気づいていると思いますが、カーネルとデータのアダマール積をひたすらカーネルを移動させながら繰り返すという処理になっています。適当にCUDAのコード書いてみるとこうなりそうですね。
__global__ convolution(float *result, float *input, float *kernel) {
int gidx = blockDim.x * blockIdx.x + threadIdx.x;
int gidy = blockDim.y * blockIdx.y + threadIdx.y;
if (gidx < SIZE_X && gidy < SIZE_Y) {
for (int i = -KER_X/2; i < KER_X/2; i++) {
for (int j = -KER_Y/2; j < KER_Y/2; j++) {
result[gidy*SIZE_X + gidx] +=
input[(gidy+j)*SIZE_X + (gidx+i)]*kernel[j*KER_X + i];
}
}
}
}
大体のカーネルサイズは3x3とかなんでループ文なくした方が速くなりそうではありますが、概念説明なので適当に書いておきました。GPUで並列処理として実装するとなったらこうなりそうですね。もちろんOpenMPによるCPU並列化も同様の考え方で畳み込み演算を実装することになります。
IMAXではどう実装すればいいか?
たくさんのスレッドを同時実行できれば上のような考え方は有効ですが、果たしてIMAXではどうでしょう。仕様書の図を見てみましょう。CNNの3x3カーネルを想定した例題となります。
わかりづらいかもしれませんが、(a)の左が入力画像、真ん中がカーネル、右が出力となっています。IC
は入力チャンネル数、OC
は出力チャンネル数を意味します。
Ull CHIP; Ull LOOP1, LOOP0; Ull INIT1, INIT0;
Ull AR[64][4]; /* output of EX in each unit */
Ull BR[64][4][4]; /* output registers in each unit */
Ull r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, r11, r12, r13, r14, r15;
Ull r16, r17, r18, r19, r20, r21, r22, r23, r24, r25, r26, r27, r28, r29, r30, r31;
Ull cc0, cc1, cc2, cc3, ex0, ex1; Ull cofs, rofs, oofs;
for (top=1; top<M-1; top+=RMGRP) {
for (iset=0; iset<IC; iset+=IMAP) { /* accumulate multiple sets of IC */
Uint *ip0 = &in[(iset+0)*M*M]; /* top of input#0 */
Uint *it00 = ip0+(top-1)*M+1-1, *ip00 = ip0+(top-1)*M+1-1, *ip01 = ip0+(top-1)*M+1+0, *ip02 = ip0+(top-1)*M+1+1;
Uint *ip03 = ip0+(top+0)*M+1-1, *ip04 = ip0+(top+0)*M+1+0, *ip05 = ip0+(top+0)*M+1+1;
Uint *ip06 = ip0+(top+1)*M+1-1, *ip07 = ip0+(top+1)*M+1+0, *ip08 = ip0+(top+1)*M+1+1;
:
for (oc=0; oc<OC/NCHIP; oc+=W) { /* set output channel */
Uint *kp00[NCHIP], *kp01[NCHIP], *kp02[NCHIP], *kp03[NCHIP];
:
Uint *kp50[NCHIP], *kp51[NCHIP], *kp52[NCHIP], *kp53[NCHIP];
Uint *op0[NCHIP], *op1[NCHIP], *op2[NCHIP], *op3[NCHIP];
Uint *ot0[NCHIP], *ot1[NCHIP], *ot2[NCHIP], *ot3[NCHIP];
for (CHIP=0; CHIP<NCHIP; CHIP++) { /* output channels are parallelized by multi-chip (OC/#chip) */
Uint choc = CHIP*OC/NCHIP+oc;
kp00[CHIP] = ker+((choc+0)*IC+iset+0)*K*K;
kp01[CHIP] = ker+((choc+1)*IC+iset+0)*K*K;
kp02[CHIP] = ker+((choc+2)*IC+iset+0)*K*K;
kp03[CHIP] = ker+((choc+3)*IC+iset+0)*K*K;
:
op0[CHIP] = out1+(choc+0)*M*M+top*M+1;
op1[CHIP] = out1+(choc+1)*M*M+top*M+1;
op2[CHIP] = out1+(choc+2)*M*M+top*M+1;
op3[CHIP] = out1+(choc+3)*M*M+top*M+1;
ot0[CHIP] = out1+(choc+0)*M*M+top*M+0;
ot1[CHIP] = out1+(choc+1)*M*M+top*M+0;
ot2[CHIP] = out1+(choc+2)*M*M+top*M+0;
ot3[CHIP] = out1+(choc+3)*M*M+top*M+0;
}
//EMAX5A begin cnn mapdist=0
for (CHIP=0; CHIP<NCHIP; CHIP++) { /* output channels are parallelized by multi-chip (OC/#chip) */
for (INIT1=1,LOOP1=RMGRP,rofs=0-M*4; LOOP1--; INIT1=0) { /* stage#0 *//* mapped to FOR() on BR[63][1][0] */
for (INIT0=1,LOOP0=M-2,cofs=0-4; LOOP0--; INIT0=0) { /* stage#0 *//* mapped to FOR() on BR[63][0][0] */
exe(OP_ADD, &cofs, INIT0?cofs:cofs, EXP_H3210, 4, EXP_H3210, 0LL, EXP_H3210, OP_AND, 0x00000000ffffffffLL, OP_NOP, 0LL);/* stage#0 */
exe(OP_ADD, &rofs, rofs, EXP_H3210, INIT0?M*4:0, EXP_H3210, 0LL, EXP_H3210, OP_NOP, 0LL, OP_NOP, 0LL); /* stage#0 */
exe(OP_ADD, &oofs, rofs, EXP_H3210, cofs, EXP_H3210, 0LL, EXP_H3210, OP_AND, 0x00000000ffffffffLL, OP_NOP, 0LL); /* stage#1 */
mop(OP_LDWR, 1, &BR[2][0][1], (Ull)kp00[CHIP], 0LL, MSK_D0, (Ull)ker, IC*OC*K*K, 0, 0, (Ull)NULL, IC*OC*K*K); /* stage#2 */
mop(OP_LDWR, 1, &BR[2][0][0], (Ull)kp01[CHIP], 0LL, MSK_D0, (Ull)ker, IC*OC*K*K, 0, 0, (Ull)NULL, IC*OC*K*K); /* stage#2 */
mop(OP_LDWR, 1, &BR[2][1][1], (Ull)kp02[CHIP], 0LL, MSK_D0, (Ull)ker, IC*OC*K*K, 0, 0, (Ull)NULL, IC*OC*K*K); /* stage#2 */
mop(OP_LDWR, 1, &BR[2][1][0], (Ull)kp03[CHIP], 0LL, MSK_D0, (Ull)ker, IC*OC*K*K, 0, 0, (Ull)NULL, IC*OC*K*K); /* stage#2 10KB */
mop(OP_LDWR, 1, &BR[2][2][1], (Ull)ip00, oofs, MSK_W0, (Ull)it00, M*(RMGRP+2), 0, 0, (Ull)NULL, M*(RMGRP+2)); /* stage#2 8KB */
/****in0*****/
exe(OP_FML, &AR[3][0], BR[2][2][1], EXP_H3210, BR[2][0][1], EXP_H3210, 0LL, EXP_H3210, OP_NOP, 0LL, OP_NOP, 0LL); /* stage#3 */
exe(OP_FML, &AR[3][1], BR[2][2][1], EXP_H3210, BR[2][0][0], EXP_H3210, 0LL, EXP_H3210, OP_NOP, 0LL, OP_NOP, 0LL); /* stage#3 */
exe(OP_FML, &AR[3][2], BR[2][2][1], EXP_H3210, BR[2][1][1], EXP_H3210, 0LL, EXP_H3210, OP_NOP, 0LL, OP_NOP, 0LL); /* stage#3 */
exe(OP_FML, &AR[3][3], BR[2][2][1], EXP_H3210, BR[2][1][0], EXP_H3210, 0LL, EXP_H3210, OP_NOP, 0LL, OP_NOP, 0LL); /* stage#3 */
mop(OP_LDWR, 1, &BR[3][0][1], (Ull)kp00[CHIP], 4LL, MSK_D0, (Ull)ker, IC*OC*K*K, 0, 0, (Ull)NULL, IC*OC*K*K); /* stage#3 */
mop(OP_LDWR, 1, &BR[3][0][0], (Ull)kp01[CHIP], 4LL, MSK_D0, (Ull)ker, IC*OC*K*K, 0, 0, (Ull)NULL, IC*OC*K*K); /* stage#3 */
mop(OP_LDWR, 1, &BR[3][1][1], (Ull)kp02[CHIP], 4LL, MSK_D0, (Ull)ker, IC*OC*K*K, 0, 0, (Ull)NULL, IC*OC*K*K); /* stage#3 */
mop(OP_LDWR, 1, &BR[3][1][0], (Ull)kp03[CHIP], 4LL, MSK_D0, (Ull)ker, IC*OC*K*K, 0, 0, (Ull)NULL, IC*OC*K*K); /* stage#3 */
mop(OP_LDWR, 1, &BR[3][2][1], (Ull)ip01, oofs, MSK_W0, (Ull)it00, M*(RMGRP+2), 0, 0, (Ull)NULL, M*(RMGRP+2)); /* stage#3 */
:
/****in5*****/
exe(OP_FMA, &AR[48][0], AR[47][0], EXP_H3210, BR[47][2][1], EXP_H3210, BR[47][0][1], EXP_H3210, OP_NOP, 0LL, OP_NOP, 0LL); /* stage#48 */
exe(OP_FMA, &AR[48][1], AR[47][1], EXP_H3210, BR[47][2][1], EXP_H3210, BR[47][0][0], EXP_H3210, OP_NOP, 0LL, OP_NOP, 0LL); /* stage#48 */
exe(OP_FMA, &AR[48][2], AR[47][2], EXP_H3210, BR[47][2][1], EXP_H3210, BR[47][1][1], EXP_H3210, OP_NOP, 0LL, OP_NOP, 0LL); /* stage#48 */
exe(OP_FMA, &AR[48][3], AR[47][3], EXP_H3210, BR[47][2][1], EXP_H3210, BR[47][1][0], EXP_H3210, OP_NOP, 0LL, OP_NOP, 0LL); /* stage#48 */
mop(OP_LDWR, 1, &BR[48][0][1], (Ull)kp50[CHIP], 4LL, MSK_D0, (Ull)ker, IC*OC*K*K, 0, 0, (Ull)NULL, IC*OC*K*K); /* stage#48 */
mop(OP_LDWR, 1, &BR[48][0][0], (Ull)kp51[CHIP], 4LL, MSK_D0, (Ull)ker, IC*OC*K*K, 0, 0, (Ull)NULL, IC*OC*K*K); /* stage#48 */
mop(OP_LDWR, 1, &BR[48][1][1], (Ull)kp52[CHIP], 4LL, MSK_D0, (Ull)ker, IC*OC*K*K, 0, 0, (Ull)NULL, IC*OC*K*K); /* stage#48 */
mop(OP_LDWR, 1, &BR[48][1][0], (Ull)kp53[CHIP], 4LL, MSK_D0, (Ull)ker, IC*OC*K*K, 0, 0, (Ull)NULL, IC*OC*K*K); /* stage#48 */
mop(OP_LDWR, 1, &BR[48][2][1], (Ull)ip51, oofs, MSK_W0, (Ull)it50, M*(RMGRP+2), 0, 0, (Ull)NULL, M*(RMGRP+2)); /* stage#48 */
:
exe(OP_FMA, &AR[53][0], AR[52][0], EXP_H3210, BR[52][2][1], EXP_H3210, BR[52][0][1], EXP_H3210, OP_NOP, 0LL, OP_NOP, 0LL); /* stage#53 */
exe(OP_FMA, &AR[53][1], AR[52][1], EXP_H3210, BR[52][2][1], EXP_H3210, BR[52][0][0], EXP_H3210, OP_NOP, 0LL, OP_NOP, 0LL); /* stage#53 */
exe(OP_FMA, &AR[53][2], AR[52][2], EXP_H3210, BR[52][2][1], EXP_H3210, BR[52][1][1], EXP_H3210, OP_NOP, 0LL, OP_NOP, 0LL); /* stage#53 */
exe(OP_FMA, &AR[53][3], AR[52][3], EXP_H3210, BR[52][2][1], EXP_H3210, BR[52][1][0], EXP_H3210, OP_NOP, 0LL, OP_NOP, 0LL); /* stage#53 */
mop(OP_LDWR, 1, &BR[53][0][1], (Ull)kp50[CHIP], 24LL, MSK_D0, (Ull)ker, IC*OC*K*K, 0, 0, (Ull)NULL, IC*OC*K*K); /* stage#53 */
mop(OP_LDWR, 1, &BR[53][0][0], (Ull)kp51[CHIP], 24LL, MSK_D0, (Ull)ker, IC*OC*K*K, 0, 0, (Ull)NULL, IC*OC*K*K); /* stage#53 */
mop(OP_LDWR, 1, &BR[53][1][1], (Ull)kp52[CHIP], 24LL, MSK_D0, (Ull)ker, IC*OC*K*K, 0, 0, (Ull)NULL, IC*OC*K*K); /* stage#53 */
mop(OP_LDWR, 1, &BR[53][1][0], (Ull)kp53[CHIP], 24LL, MSK_D0, (Ull)ker, IC*OC*K*K, 0, 0, (Ull)NULL, IC*OC*K*K); /* stage#53 */
mop(OP_LDWR, 1, &BR[53][2][1], (Ull)ip56, oofs, MSK_W0, (Ull)it50, M*(RMGRP+2), 0, 0, (Ull)NULL, M*(RMGRP+2)); /* stage#53 */
exe(OP_FMA, &AR[54][0], AR[53][0], EXP_H3210, BR[53][2][1], EXP_H3210, BR[53][0][1], EXP_H3210, OP_NOP, 0LL, OP_NOP, 0LL); /* stage#54 */
exe(OP_FMA, &AR[54][1], AR[53][1], EXP_H3210, BR[53][2][1], EXP_H3210, BR[53][0][0], EXP_H3210, OP_NOP, 0LL, OP_NOP, 0LL); /* stage#54 */
exe(OP_FMA, &AR[54][2], AR[53][2], EXP_H3210, BR[53][2][1], EXP_H3210, BR[53][1][1], EXP_H3210, OP_NOP, 0LL, OP_NOP, 0LL); /* stage#54 */
exe(OP_FMA, &AR[54][3], AR[53][3], EXP_H3210, BR[53][2][1], EXP_H3210, BR[53][1][0], EXP_H3210, OP_NOP, 0LL, OP_NOP, 0LL); /* stage#54 */
mop(OP_LDWR, 1, &BR[54][0][1], (Ull)kp50[CHIP], 28LL, MSK_D0, (Ull)ker, IC*OC*K*K, 0, 0, (Ull)NULL, IC*OC*K*K); /* stage#54 */
mop(OP_LDWR, 1, &BR[54][0][0], (Ull)kp51[CHIP], 28LL, MSK_D0, (Ull)ker, IC*OC*K*K, 0, 0, (Ull)NULL, IC*OC*K*K); /* stage#54 */
mop(OP_LDWR, 1, &BR[54][1][1], (Ull)kp52[CHIP], 28LL, MSK_D0, (Ull)ker, IC*OC*K*K, 0, 0, (Ull)NULL, IC*OC*K*K); /* stage#54 */
mop(OP_LDWR, 1, &BR[54][1][0], (Ull)kp53[CHIP], 28LL, MSK_D0, (Ull)ker, IC*OC*K*K, 0, 0, (Ull)NULL, IC*OC*K*K); /* stage#54 */
mop(OP_LDWR, 1, &BR[54][2][1], (Ull)ip57, oofs, MSK_W0, (Ull)it50, M*(RMGRP+2), 0, 0, (Ull)NULL, M*(RMGRP+2)); /* stage#54 */
exe(OP_FMA, &AR[55][0], AR[54][0], EXP_H3210, BR[54][2][1], EXP_H3210, BR[54][0][1], EXP_H3210, OP_NOP, 0LL, OP_NOP, 0LL); /* stage#55 */
exe(OP_FMA, &AR[55][1], AR[54][1], EXP_H3210, BR[54][2][1], EXP_H3210, BR[54][0][0], EXP_H3210, OP_NOP, 0LL, OP_NOP, 0LL); /* stage#55 */
exe(OP_FMA, &AR[55][2], AR[54][2], EXP_H3210, BR[54][2][1], EXP_H3210, BR[54][1][1], EXP_H3210, OP_NOP, 0LL, OP_NOP, 0LL); /* stage#55 */
exe(OP_FMA, &AR[55][3], AR[54][3], EXP_H3210, BR[54][2][1], EXP_H3210, BR[54][1][0], EXP_H3210, OP_NOP, 0LL, OP_NOP, 0LL); /* stage#55 */
mop(OP_LDWR, 1, &BR[55][0][1], (Ull)kp50[CHIP], 32LL, MSK_D0, (Ull)ker, IC*OC*K*K, 0, 0, (Ull)NULL, IC*OC*K*K); /* stage#55 */
mop(OP_LDWR, 1, &BR[55][0][0], (Ull)kp51[CHIP], 32LL, MSK_D0, (Ull)ker, IC*OC*K*K, 0, 0, (Ull)NULL, IC*OC*K*K); /* stage#55 */
mop(OP_LDWR, 1, &BR[55][1][1], (Ull)kp52[CHIP], 32LL, MSK_D0, (Ull)ker, IC*OC*K*K, 0, 0, (Ull)NULL, IC*OC*K*K); /* stage#55 */
mop(OP_LDWR, 1, &BR[55][1][0], (Ull)kp53[CHIP], 32LL, MSK_D0, (Ull)ker, IC*OC*K*K, 0, 0, (Ull)NULL, IC*OC*K*K); /* stage#55 */
mop(OP_LDWR, 1, &BR[55][2][1], (Ull)ip58, oofs, MSK_W0, (Ull)it50, M*(RMGRP+2), 0, 0, (Ull)NULL, M*(RMGRP+2)); /* stage#55 */
exe(OP_FMA, &AR[56][0], AR[55][0], EXP_H3210, BR[55][2][1], EXP_H3210, BR[55][0][1], EXP_H3210, OP_NOP, 0LL, OP_NOP, 0LL); /* stage#56 */
exe(OP_FMA, &AR[56][1], AR[55][1], EXP_H3210, BR[55][2][1], EXP_H3210, BR[55][0][0], EXP_H3210, OP_NOP, 0LL, OP_NOP, 0LL); /* stage#56 */
exe(OP_FMA, &AR[56][2], AR[55][2], EXP_H3210, BR[55][2][1], EXP_H3210, BR[55][1][1], EXP_H3210, OP_NOP, 0LL, OP_NOP, 0LL); /* stage#56 */
exe(OP_FMA, &AR[56][3], AR[55][3], EXP_H3210, BR[55][2][1], EXP_H3210, BR[55][1][0], EXP_H3210, OP_NOP, 0LL, OP_NOP, 0LL); /* stage#56 */
mop(OP_LDWR, 1, &BR[57][0][1], (Ull)op0[CHIP], oofs, MSK_W0, (Ull)ot0[CHIP], M*RMGRP, 0, 1, (Ull)NULL, M*RMGRP); /* stage#57 */
mop(OP_LDWR, 1, &BR[57][1][1], (Ull)op1[CHIP], oofs, MSK_W0, (Ull)ot1[CHIP], M*RMGRP, 0, 1, (Ull)NULL, M*RMGRP); /* stage#57 */
mop(OP_LDWR, 1, &BR[57][2][1], (Ull)op2[CHIP], oofs, MSK_W0, (Ull)ot2[CHIP], M*RMGRP, 0, 1, (Ull)NULL, M*RMGRP); /* stage#57 */
mop(OP_LDWR, 1, &BR[57][3][1], (Ull)op3[CHIP], oofs, MSK_W0, (Ull)ot3[CHIP], M*RMGRP, 0, 1, (Ull)NULL, M*RMGRP); /* stage#57 */
exe(OP_FAD, &AR[57][0], AR[56][0], EXP_H3210, BR[57][0][1], EXP_H3210, 0LL, EXP_H3210, OP_NOP, 0LL, OP_NOP, 0LL); /* stage#57 */
exe(OP_FAD, &AR[57][1], AR[56][1], EXP_H3210, BR[57][1][1], EXP_H3210, 0LL, EXP_H3210, OP_NOP, 0LL, OP_NOP, 0LL); /* stage#57 */
exe(OP_FAD, &AR[57][2], AR[56][2], EXP_H3210, BR[57][2][1], EXP_H3210, 0LL, EXP_H3210, OP_NOP, 0LL, OP_NOP, 0LL); /* stage#57 */
exe(OP_FAD, &AR[57][3], AR[56][3], EXP_H3210, BR[57][3][1], EXP_H3210, 0LL, EXP_H3210, OP_NOP, 0LL, OP_NOP, 0LL); /* stage#57 */
mop(OP_STWR, 1, &AR[57][0], oofs, (Ull)op0[CHIP], MSK_D0, (Ull)ot0[CHIP], M*RMGRP, 0, 1, (Ull)NULL, M*RMGRP); /* stage#57 8KB */
mop(OP_STWR, 1, &AR[57][1], oofs, (Ull)op1[CHIP], MSK_D0, (Ull)ot1[CHIP], M*RMGRP, 0, 1, (Ull)NULL, M*RMGRP); /* stage#57 8KB */
mop(OP_STWR, 1, &AR[57][2], oofs, (Ull)op2[CHIP], MSK_D0, (Ull)ot2[CHIP], M*RMGRP, 0, 1, (Ull)NULL, M*RMGRP); /* stage#57 8KB */
mop(OP_STWR, 1, &AR[57][3], oofs, (Ull)op3[CHIP], MSK_D0, (Ull)ot3[CHIP], M*RMGRP, 0, 1, (Ull)NULL, M*RMGRP); /* stage#57 8KB */
}
}
}
//EMAX5A end
}
}
}
//EMAX5A drain_dirty_lmm
長いコードに引いてしまった方もいらっしゃると思いますが、一つずつ解説していきます。
まず基本的な考え方は同じと考えてもいいと思います。カーネルと入力画像のアダマール積を取って足す演算を繰り返す感じですね。
ipxx
は入力画像のカーネルとのアダマール積を行う部分のアドレスを指します。kpxx[CHIP]
はカーネルのアドレスです。M
は入力画像のサイズとなります。一チャンネルの画像を分割を行い、アドレスを設定し、rofs
及びcofs
で行及び列オフセットを指定し、その合計を出力のオフセットとし、畳み込み演算を行います。
入力画像のピクセルをユニットに固定し、複数のカーネルに対して積和演算を行い、格納するという認識でいいと思います。一個のチップで格納できる結果は一サイクルあたり四つまでなので、上での積和演算もこの数に縛られます。そのようなことを踏まえて、例題は作られています。また、このようなことから、64ユニットのリニアアレイ型CGRAだけではなく、複数チップによる並列化が重要となってきます。kpxx[CHIP]
及びopxx[CHIP]
がipxx
と違い配列になっているのもこのような理由からです。
カーネルのサイズがもし変わったら?
CUDAの例題ではわかりやすくするために二重for文による説明をしましたが、入力画像のピクセルを各ユニットに固定するIMAXのやり方の特性上、複数のカーネルを用意するしかないと思われます。しかし、大体のばあい固定されているので、実用上問題はないと思われます。
終わりに
Qiitaのエディターの反応速度が非常に遅く、短くして切り上げることになりますが、これで理解できればと願っています。次回は、また別の例題を紹介します。