Winograd 畳み込みアルゴリズム(以下、Winograd アルゴリズム) は、畳み込み演算を高速に計算する方法の一つです。2016年に Fast Algorithms for Convolutional Neural Networks という論文で提案されて以来、畳み込みニューラルネットワークの高速化に広く使われています。NVIDIA の cuDNN、AMD の MIOpen といったライブラリには、アセンブリで書かれたWinograd アルゴリズムの実装があり、入力が適用条件を満たす場合はほぼ Winograd アルゴリズムが使われるようです。
TVM の CUDA バックエンドには、畳み込みの実装には定義通りの "Direct" アルゴリズムしかありません。そこで、TVM による Winograd アルゴリズムの GPU 実装をしようと思い立ち、先日いろいろと試してみました。結果的に、様々な入力について、TVM 本家 の Direct アルゴリズムよりも高速な畳み込みの実装にたどりつくことができました。本記事では、そのときに得られた知見を共有したいと思います。最初に実装した遅いカーネルをどのように高速化したかを、ステップバイステップで示します。本記事を通じて、TVM を使うとGPUコードを書くのがいかに簡単になるか、伝われば幸いです。
Winograd アルゴリズム
まず、以降で必要となる最低限の事柄についてまとめます。Winograd アルゴリズムの詳細については、Fast Algorithms for Convolutional Neural Networks をご覧下さい。以降、記号は上論文と同じものを使います。本記事では、上論文の 4.1節に説明されている Winograd F(2x2, 3x3) アルゴリズムの実装、高速化をします。
F(2x2, 3x3) アルゴリズムでは、入力画像を 4 x 4 のタイルに分割し、タイルごとに 2 x 2 の出力を計算します。隣り合うタイルは 2 ピクセル分重なるようにタイリングされます。それぞれの 4 x 4 のタイル $d$ は、行列 $B$,
B^T = \left[
\begin{array}{rrr}
1 & 0 & -1 & 0 \\
0 & 1 & 1 & 0\\
0 & -1 & 1 & 0 \\
0 & 1 & 0 & -1
\end{array}
\right]
によって、$B^T d B$ と変換されます。同様に、3 x 3 入力フィルタ $g$ は 以下の行列 $G$ によって、$GgG^T$ と変換されます。
G = \left[
\begin{array}{rrr}
1 & 0 & 0 \\
0.5 & 0.5 & 0.5 \\
0.5 & -0.5 & 0.5 \\
0 & 0 & 1
\end{array}
\right]
画像タイルの変換 $B^T d B$、フィルタの変換 $GgG^T$ はそれぞれ 4 x 4 の行列です。Winograd アルゴリズムでは、これらの行列の要素積をとり、さらに4 x 2 の行列 $A^T$,
A^T= \left[
\begin{array}{rrr}
1 & 1 & 1 & 0 \\
0 & 1 & -1 & -1
\end{array}
\right]
を左右からかけて、2 x 2 の出力
$$ Y = A^T \bigl[[GgG^T] \odot [B^T d B]\bigr] A $$
を得ます。
ここまでは 1 つの 4 x 4 入力タイルから、2 x 2 の出力タイルを得る手順を説明しました。畳み込みネットワークでは、入力は縦H・横W に加えてチャネル方向C、バッチ方向 Nの次元を持ち、出力はバッチ N, フィルター数 K、 縦H、横 W からなります。出力フィルターの 1つ分の 2 x 2 タイルは、 チャネル C ごとに計算された 2 x 2 のタイル $Y$ を足し合わせることで計算されます。出力の全 N x K x (タイル数) 分の 2 x 2 タイルは、上記の手順で独立に計算されます。
Fast Algorithms for Convolutional Neural Networks では、上記の全出力タイル分の計算を、行列積でまとめて表現することで高速に行う方法を提案しています。4 x 4 の要素積 $[[GgG^T] \odot [B^T d B]\bigr]$ をチャネル方向に足し合わせる計算は、要素ごとに内積を計算していると見立てることができます。さらに、入力・出力の全タイルをそれぞれ適切な行列に詰め、それらの積を計算することによって、全出力タイル分の内積を一度に計算することができます。詳細については論文をご覧ください。ここでは、論文の式$(13)$、
$$ M^{(\xi,\nu)} = U^{(\xi,\nu)} V^{(\xi,\nu)} (13)$$
について補足しておきます。ここで計算されている $M^{(\xi,\nu)}$ は行列ですが、これは 4 x 4 x (出力フィルタ数 K) x (全バッチ分のタイル数) の4次元配列 $M$ を、0 から 3の添え字のペア $(\xi, \nu)$ でインデックスしたものです。$U^{(\xi,\nu)}$、$V^{(\xi,\nu)}$についても同様で、 $U$ は 4 x 4 x (出力フィルタ数 K) x (入力フィルタ数 C) の配列、$V$ は4 x 4 x (入力フィルタ数 C) x (全バッチ分のタイル数) の配列です。$U^{(\xi,\nu)}$は、4 x 4 のフィルタの変換 $GgG^T$ の、インデックス $(\xi, \nu)$ における値を、出力フィルタ数 K 個、入力フィルタ数 C 個並べてできた行列になります。$V^{(\xi,\nu)}$ についても同様で、こちらは入力画像の変換 $B^T d B$ から、インデックス $(\xi, \nu)$の値を入力フィルタ数 C 個、全バッチ分のタイル数 N x K x (タイル数) 個並べたものになります。このように行列をつくることで、元々やろうとしていた 4 x 4 の要素積とチャネル方向の和を、4 x 4 個分の独立した行列積で全出力タイル分計算することができます。行列積には、理論ピークに近い性能を出せるとても効率的な GPU 実装があるので、式 $(13)$ の計算は高速に行うことができます。また、複数の行列積を一度に行うことはよくあるようで、Batched GEMM と呼ばれています。例えば、cuBLAS にはcublasSgemmBatched(...)という API があり、これを使って 4 x 4 個分の式$(13)$ の計算をさらに高速に行うことができます。
TVM による Winograd アルゴリズムの定義
ここからは、TVM による Winograd アルゴリズムの実装に移ります。TVM は、Halideと同じように、なにを計算するのか(what) と、どう計算するのか (how) を分離する、というデザインを採用しています。そのため、まずは what にあたる Winograd アルゴリズムの定義を、TVM の API で記述します。
冒頭で述べたように、TVM の CUDA バックエンドには Winograd アルゴリズムの実装はありませんが、ARM Mali GPU バックエンドには実装があります。ただ、Mali バックエンドの実装を CUDA から使うことはできないので、Mali の実装を参考にして NVIDIA の GPU に適したコードに書き換えました。ここからコード断片を交えて説明しますが、コード全体はこのリポジトリにおいてあります。
まず、アルゴリズムの入力は画像データとフィルターですが、TVM のようなデプロイ用のコンパイラスタックにとっては、フィルターの値はコンパイル時に渡された値からデプロイ時まで変化しないので、フィルターの変換はコンパイル時に事前計算しておくものとします。そのため、ここで定義するWinograd アルゴリズムの入力は、(N, C, H, W) の入力画像と、(4, 4, K, C) のフィルター変換 U からなります。
はじめに、各定数の値を設定します。入力のパディングもここで行います。
import numpy as np
import tvm
import topi
from topi import util
from topi.nn import pad
# data は N x C x H x W の入力、U は 4 x 4 x K x C のフィルター変換
data = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
U = tvm.placeholder((4, 4, num_filter, in_channel), name='U')
N, C H, W = [util.get_const_int(x) for x in data.shape]
_, _, K, C = [util.get_const_int(x) for x in U.shape]
m = 2
r = 3
alpha = m + r - 1 # タイルの幅. F(2x2, 3x3) では alpha = 4
nH, nW = (H + m-1) // m, (W + m-1) // m
P = N * nH * nW # 全タイル数
data_pad = pad(data, (0, 0, 1, 1), name="data_pad")
入力の変換行列 $B$ と、逆変換行列 $A$ を定義します。関数 const_array() は、Numpy 配列から TVM の定数配列をつくる関数です。
B_data = np.array([
[1, 0, 0, 0],
[0, 1, -1, 1],
[-1, 1, 1, 0],
[0, 0, 0, -1]
])
B = const_array(B_data, 'B')
A_data = np.array([
[1, 0],
[1, 1],
[1, -1],
[0, -1],
])
A = const_array(A_data, 'A')
これで、Winograd アルゴリズムの定義をする準備ができました。最初のステップは、入力の(N, C, H, W) 配列を、(チャネル数 C、全タイル数 P, 4, 4) のタイル状に並び替えます。さらに、タイルごとに入力変換 $B^T d B$ を計算し、配列 $V$ を定義します。
d = tvm.compute((C, P, alpha, alpha),
lambda c, b, eps, nu:
tvm.select(b < P, \\
data_pad[b // (nH*nW)][c][b// nW % nH * m + eps][b % nW * m + nu],\\
tvm.const(0, data_pad.dtype)), name='d')
r_eps = tvm.reduce_axis((0, alpha), 'r_eps')
r_nu = tvm.reduce_axis((0, alpha), 'r_nu')
V = tvm.compute((alpha, alpha, P, C), lambda eps, nu, b, c:
tvm.sum(d[c][b][r_eps][r_nu] * B[r_eps][eps] * B[r_nu][nu],
axis=[r_eps, r_nu]), name='V')
入力として与えられた $U$ と、先ほど定義した $V$ から、4 x 4 個分の行列積を行い、配列 $M$ を定義します。
c = tvm.reduce_axis((0, C), name='c')
M = tvm.compute((alpha, alpha, K, P), lambda eps, nu, k, b:
tvm.sum(U[eps][nu][k][c] *
V[eps][nu][b][c], axis=c), name='M')
ここまでで、4 x 4 の要素積とチャネル方向の和が計算されました。最後に、4 x 4 のタイルに左右からそれぞれ行列 $A^T$, $A$ をかけ、2 x 2 の出力タイルを得ます。全バッチ、出力チャネル、タイル数分の 2 x 2 タイル を(N, K, H, W) の形式に並び替えて、最終的な出力を得ます。
r_eps = tvm.reduce_axis((0, alpha), 'r_eps')
r_nu = tvm.reduce_axis((0, alpha), 'r_nu')
output = tvm.compute((N, K, H, W), lambda n, k, h, w:
tvm.sum(M[r_eps][r_nu][k][n * nH * nW + (h//m) * nW + w//m] * \\
A[r_eps][h % m] * A[r_nu][w % m],
axis=[r_eps, r_nu]), name='output')
これで、Winograd アルゴリズムの定義ができました。
ナイーブな GPU 実装
ここからは、上で定義した Winograd アルゴリズムを、NVIDIA GPU 向けにどのように並列化するか説明します。Halide や TVM では、このようなアルゴリズムの how にあたる部分をスケジュールと呼んでいます。
上で定義したアルゴリズムは、主に以下の 3 つのステップからなります。
- 入力の変換 $V$ の計算
- 4 x 4 個分の行列積 $M$ の計算
- 出力の逆変換
今回は、上記 3 つのステップを、別々の GPU カーネルとしてスケジューリングします。最初は、考えうる最もシンプルなカーネルを実装しましょう。GPU カーネルとして最も単純なのは、それぞれのスレッドが一つの出力のみを独立に計算する、というものです。最初のステップである入力の変換 $V$ を、そのようにスケジュールするには、以下のように記述します。TVM のスケジューリング用の API については、公式ホームページにチュートリアルがありますので、あわせてご覧ください。
num_thread = 16
s = tvm.create_schedule([output.op])
s[B].compute_inline()
s[d].compute_inline() # 入力画像のタイリングは V 内のループ中で行う
eps, nu, p, c = s[V].op.axis
po, pi = s[V].split(p, factor=num_thread)
co, ci = s[V].split(c, factor=num_thread)
s[V].reorder(eps, nu, po, co, pi, ci)
fused = s[V].fuse(eps, nu, po, co)
s[V].bind(pi, tvm.thread_axis("threadIdx.y"))
s[V].bind(ci, tvm.thread_axis("threadIdx.x"))
s[V].bind(fused, tvm.thread_axis("blockIdx.x"))
同様に、行列積 $M$ と出力の逆変換のスケジュールは、以下のように記述できます。
eps, nu, k, p = s[M].op.axis
ko, ki = s[M].split(k, factor=num_thread)
po, pi = s[M].split(p, factor=num_thread)
z = s[M].fuse(eps, nu)
s[M].bind(ki, tvm.thread_axis("threadIdx.y"))
s[M].bind(pi, tvm.thread_axis("threadIdx.x"))
s[M].bind(ko, tvm.thread_axis("blockIdx.y"))
s[M].bind(po, tvm.thread_axis("blockIdx.x"))
s[M].bind(z, tvm.thread_axis("blockIdx.z"))
s[A].compute_inline()
n, k, h, w = s[output].op.axis
ho, hi = s[output].split(h, factor=num_thread)
wo, wi = s[output].split(w, factor=num_thread)
s[output].reorder(k, ho, wo, hi, wi)
fused = s[output].fuse(k, ho, wo)
s[output].bind(hi, tvm.thread_axis("threadIdx.y"))
s[output].bind(wi, tvm.thread_axis("threadIdx.x"))
s[output].bind(fused, tvm.thread_axis("blockIdx.x"))
これで、TVM による Winograd アルゴリズムの最初の GPU 実装ができました。本記事で実装する全コードが含まれるリポジトリには、ここで定義したナイーブなスケジュールを含むブランチがあります。このブランチのコードを実行すると、以下のような出力を得ます。
$ python wino_test.py
Winograd: 23.614 msec, Reference: 2.138 msec
"Winograd" は今回実装した Winograd アルゴリズムの実行時間、"Reference" は TVM 本家の Direct Convolution の実行時間です。見ての通り、Winograd のほうが10倍以上遅いですね。ここで実装したのはナイーブな並列化なので、当然といえます。また、上のコードでは Winograd アルゴリズムの出力結果が正しいかどうかのチェックもされています。なにもエラーが出ていないということは、結果が正しいということです。
次回
ここまで、Winograd アルゴリズムの概要、TVM によるアルゴリズムの定義、ナイーブな GPU 実装、について説明しました。次回の記事では、このナイーブなスケジュールを高速化し、最終的に TVM 本家の Direct Convolution よりも速い実装が得られるまでの過程を示します。
現時点で一番高速な実装はこのリポジトリにあります。