今日も元気に深層学習の高速化をしている数値計算屋さんの皆さま、こんばんは。
最近の深層学習はほぼConvolutional NNですが、そのConvolution(畳み込み)の計算がim2colを経て行列積になるというのは常識になっていると思います。
では、実際にどんな形の畳み込みがどんな行列積になるのか?
よく問われて毎回算出方法忘れて調べているので、メモがてらまとめておきます。
なお、im2colの実装は一意ではない(行列の転置がありうる)ので、ここではChainerを参照します。
TL; DR
- $N = B W_O H_O$
- $M = C_O $
- $L = C_I W_K H_K$
で、N行L列とL行M列の積になります。
ここで
- $B$:ミニバッチサイズ
- $W_O, H_O$:出力画像の幅・高
- $C_O$:出力画像のチャネル数
- $C_I$:入力画像のチャネル数
- $W_K, H_K$:畳込みカーネルの幅・高
です。実際には$W_O, H_O$は入力画像とカーネルの幅・高およびストライド・パディングによって決まります。計算式は省略(Chainerのconvolutionの説明をご覧ください)。
詳細
im2colについて
im2colそのものについては既に他に優れた説明があるので省略します(例えば以下が参考になります)
ChainerのConvolutionをしてみよう
とりあえず動かしてみましょう
import chainer
import numpy
c_i = 3
w_i = 224
h_i = 224
c_o = 64
b_n = 16
conv = chainer.links.Convolution2D(c_i, c_o, ksize=7, stride=2, pad=3)
x = numpy.empty((b_n, c_i, w_i, h_i), dtype=numpy.float32)
y = conv(x)
print(y.shape) # (16, 64, 112, 112) = (b_n, c_o, w_o, h_o)
と言ったように、(b_n, c_i, w_i, h_i)の画像をConvolution2Dにかけると(b_n, c_o, w_o, h_o)が出てくるようです(それはそう)。
Chainerの中身を見てみよう(テンソル積への変換)
Chainerでは実際にはConvolutionはテンソル積(CPUの場合、NumPyのtensordotに落ちます)。その場所はconvolution_2d.pyにあります。
# https://github.com/chainer/chainer/blob/edc929818b0021543d2b8dc773aa89b41ec06d09/chainer/functions/connection/convolution_2d.py#L134-L138 より引用
col = conv.im2col_cpu(
x, kh, kw, self.sy, self.sx, self.ph, self.pw,
cover_all=self.cover_all, dy=self.dy, dx=self.dx)
y = numpy.tensordot(
col, W, ((1, 2, 3), (1, 2, 3))).astype(x.dtype, copy=False)
このcol, W, yの形を形を見るとどんなテンソル積が実行されているのかが分かりますね。
先のコードを実行し、printなりデバッガなりで止めて中身を確認すると
- col.shape= (16, 3, 7, 7, 112, 112) = (B, C_I, K, K, W_O, H_O)
- W.shape= (64, 3, 7, 7) = (C_O, C_I, K, K)
- y.shape= (16, 112, 112, 64) = (B, W_O, H_O, C_O)
でした。Chainerのim2colではテンソル積への変換にしているようですね。
NumPyの中身を見てみよう(行列積への変換)
NumPyのテンソル積のaxes引数、みんな混乱して難しいと思います。完璧に解説するのは趣旨から外れるので他をあたっていただくとして、今回は上記の呼び出しがどう普通の密行列積(いわゆるGEMM)に変換されるかだけに着目します。
コードは、numeric.pyにあります。
# https://github.com/numpy/numpy/blob/09cb2bdeb35faa79939f03cdcc8745f44d5116ca/numpy/core/numeric.py#L1030-L1091 より引用
try:
iter(axes)
except Exception:
axes_a = list(range(-axes, 0))
axes_b = list(range(0, axes))
else:
axes_a, axes_b = axes
try:
na = len(axes_a)
axes_a = list(axes_a)
except TypeError:
axes_a = [axes_a]
na = 1
try:
nb = len(axes_b)
axes_b = list(axes_b)
except TypeError:
axes_b = [axes_b]
nb = 1
a, b = asarray(a), asarray(b)
as_ = a.shape
nda = a.ndim
bs = b.shape
ndb = b.ndim
equal = True
if na != nb:
equal = False
else:
for k in range(na):
if as_[axes_a[k]] != bs[axes_b[k]]:
equal = False
break
if axes_a[k] < 0:
axes_a[k] += nda
if axes_b[k] < 0:
axes_b[k] += ndb
if not equal:
raise ValueError("shape-mismatch for sum")
# Move the axes to sum over to the end of "a"
# and to the front of "b"
notin = [k for k in range(nda) if k not in axes_a]
newaxes_a = notin + axes_a
N2 = 1
for axis in axes_a:
N2 *= as_[axis]
newshape_a = (int(multiply.reduce([as_[ax] for ax in notin])), N2)
olda = [as_[axis] for axis in notin]
notin = [k for k in range(ndb) if k not in axes_b]
newaxes_b = axes_b + notin
N2 = 1
for axis in axes_b:
N2 *= bs[axis]
newshape_b = (N2, int(multiply.reduce([bs[ax] for ax in notin])))
oldb = [bs[axis] for axis in notin]
at = a.transpose(newaxes_a).reshape(newshape_a)
bt = b.transpose(newaxes_b).reshape(newshape_b)
res = dot(at, bt)
return res.reshape(olda + oldb)
長い・・・ですが、とりあえず最後のatとbtが実際の行列積の形っぽいです。
ということで、newshape_aとnewshape_bが分かれば良さそうですね。
で、今回、以下のことが分かっています
- 入力a, bにはcol, Wが入っており、a.shape=(B, C_I, K, K, W_O, H_O), b.shape=(C_O, C_I, K, K)
- axesには((1, 2, 3), (1, 2, 3))
ということで、これらの情報を固定し展開し、また、oldaとoldbはatとbtから明らかに分かるのでこの計算を除外すると、以下のようになります
# https://github.com/numpy/numpy/blob/09cb2bdeb35faa79939f03cdcc8745f44d5116ca/numpy/core/numeric.py#L1030-L1091 を改変・大幅に省略
a_shape = (B, C_I, K, K, W_O, H_O)
b_shape = (C_O, C_I, K, K)
axes_a = [1, 2, 3]
axes_b = [1, 2, 3]
notin_a = [0, 4, 5] # 0,...,5=dim(a)-1の中でaxes_a=1,2,3でないもの
notin_b = [0] # 0,...,3=dim(b)-1の中でaxes_a=1,2,3でないもの
newshape_a = [B*W_O*H_O, C_I*K*K] # = [a_shapeの中でnotin_a(0, 4, 5)番目だけの積, a_shapeの中でaxes_a(1, 2, 3)番目だけの積])
newshape_b = [C_I*K*K, C_O] # = [b_shapeの中でaxes_b(1, 2, 3)番目だけの積, b_shapeの中でnotin_b(0)番目だけの積])
まとめ
というわけで、
- $N = B W_O H_O$
- $M = C_O $
- $L = C_I W_K H_K$
で、N行L列とL行M列の積になる、でした。