LoginSignup
0
2

More than 3 years have passed since last update.

Chainerで学ぶim2colを使った畳み込み演算と密行列積

Last updated at Posted at 2019-08-30

今日も元気に深層学習の高速化をしている数値計算屋さんの皆さま、こんばんは。

最近の深層学習はほぼConvolutional NNですが、そのConvolution(畳み込み)の計算がim2colを経て行列積になるというのは常識になっていると思います。

では、実際にどんな形の畳み込みがどんな行列積になるのか?
よく問われて毎回算出方法忘れて調べているので、メモがてらまとめておきます。

なお、im2colの実装は一意ではない(行列の転置がありうる)ので、ここではChainerを参照します。

TL; DR

  • $N = B W_O H_O$
  • $M = C_O $
  • $L = C_I W_K H_K$

で、N行L列とL行M列の積になります。

ここで

  • $B$:ミニバッチサイズ
  • $W_O, H_O$:出力画像の幅・高
  • $C_O$:出力画像のチャネル数
  • $C_I$:入力画像のチャネル数
  • $W_K, H_K$:畳込みカーネルの幅・高

です。実際には$W_O, H_O$は入力画像とカーネルの幅・高およびストライド・パディングによって決まります。計算式は省略(Chainerのconvolutionの説明をご覧ください)。

詳細

im2colについて

im2colそのものについては既に他に優れた説明があるので省略します(例えば以下が参考になります)

ChainerのConvolutionをしてみよう

とりあえず動かしてみましょう

import chainer
import numpy

c_i = 3
w_i = 224
h_i = 224
c_o = 64
b_n = 16

conv = chainer.links.Convolution2D(c_i, c_o, ksize=7, stride=2, pad=3)
x = numpy.empty((b_n, c_i, w_i, h_i), dtype=numpy.float32)
y = conv(x)
print(y.shape) # (16, 64, 112, 112) = (b_n, c_o, w_o, h_o)

と言ったように、(b_n, c_i, w_i, h_i)の画像をConvolution2Dにかけると(b_n, c_o, w_o, h_o)が出てくるようです(それはそう)。

Chainerの中身を見てみよう(テンソル積への変換)

Chainerでは実際にはConvolutionはテンソル積(CPUの場合、NumPyのtensordotに落ちます)。その場所はconvolution_2d.pyにあります。

# https://github.com/chainer/chainer/blob/edc929818b0021543d2b8dc773aa89b41ec06d09/chainer/functions/connection/convolution_2d.py#L134-L138 より引用
        col = conv.im2col_cpu(
            x, kh, kw, self.sy, self.sx, self.ph, self.pw,
            cover_all=self.cover_all, dy=self.dy, dx=self.dx)
        y = numpy.tensordot(
            col, W, ((1, 2, 3), (1, 2, 3))).astype(x.dtype, copy=False)

このcol, W, yの形を形を見るとどんなテンソル積が実行されているのかが分かりますね。
先のコードを実行し、printなりデバッガなりで止めて中身を確認すると

  • col.shape= (16, 3, 7, 7, 112, 112) = (B, C_I, K, K, W_O, H_O)
  • W.shape= (64, 3, 7, 7) = (C_O, C_I, K, K)
  • y.shape= (16, 112, 112, 64) = (B, W_O, H_O, C_O)

でした。Chainerのim2colではテンソル積への変換にしているようですね。

NumPyの中身を見てみよう(行列積への変換)

NumPyのテンソル積のaxes引数、みんな混乱して難しいと思います。完璧に解説するのは趣旨から外れるので他をあたっていただくとして、今回は上記の呼び出しがどう普通の密行列積(いわゆるGEMM)に変換されるかだけに着目します。

コードは、numeric.pyにあります。

# https://github.com/numpy/numpy/blob/09cb2bdeb35faa79939f03cdcc8745f44d5116ca/numpy/core/numeric.py#L1030-L1091 より引用
    try:
        iter(axes)
    except Exception:
        axes_a = list(range(-axes, 0))
        axes_b = list(range(0, axes))
    else:
        axes_a, axes_b = axes
    try:
        na = len(axes_a)
        axes_a = list(axes_a)
    except TypeError:
        axes_a = [axes_a]
        na = 1
    try:
        nb = len(axes_b)
        axes_b = list(axes_b)
    except TypeError:
        axes_b = [axes_b]
        nb = 1

    a, b = asarray(a), asarray(b)
    as_ = a.shape
    nda = a.ndim
    bs = b.shape
    ndb = b.ndim
    equal = True
    if na != nb:
        equal = False
    else:
        for k in range(na):
            if as_[axes_a[k]] != bs[axes_b[k]]:
                equal = False
                break
            if axes_a[k] < 0:
                axes_a[k] += nda
            if axes_b[k] < 0:
                axes_b[k] += ndb
    if not equal:
        raise ValueError("shape-mismatch for sum")

    # Move the axes to sum over to the end of "a"
    # and to the front of "b"
    notin = [k for k in range(nda) if k not in axes_a]
    newaxes_a = notin + axes_a
    N2 = 1
    for axis in axes_a:
        N2 *= as_[axis]
    newshape_a = (int(multiply.reduce([as_[ax] for ax in notin])), N2)
    olda = [as_[axis] for axis in notin]

    notin = [k for k in range(ndb) if k not in axes_b]
    newaxes_b = axes_b + notin
    N2 = 1
    for axis in axes_b:
        N2 *= bs[axis]
    newshape_b = (N2, int(multiply.reduce([bs[ax] for ax in notin])))
    oldb = [bs[axis] for axis in notin]

    at = a.transpose(newaxes_a).reshape(newshape_a)
    bt = b.transpose(newaxes_b).reshape(newshape_b)
    res = dot(at, bt)
    return res.reshape(olda + oldb)

長い・・・ですが、とりあえず最後のatとbtが実際の行列積の形っぽいです。
ということで、newshape_aとnewshape_bが分かれば良さそうですね。

で、今回、以下のことが分かっています

  • 入力a, bにはcol, Wが入っており、a.shape=(B, C_I, K, K, W_O, H_O), b.shape=(C_O, C_I, K, K)
  • axesには((1, 2, 3), (1, 2, 3))

ということで、これらの情報を固定し展開し、また、oldaとoldbはatとbtから明らかに分かるのでこの計算を除外すると、以下のようになります

# https://github.com/numpy/numpy/blob/09cb2bdeb35faa79939f03cdcc8745f44d5116ca/numpy/core/numeric.py#L1030-L1091 を改変・大幅に省略
    a_shape = (B, C_I, K, K, W_O, H_O)
    b_shape = (C_O, C_I, K, K)
    axes_a = [1, 2, 3]
    axes_b = [1, 2, 3]
    notin_a = [0, 4, 5] # 0,...,5=dim(a)-1の中でaxes_a=1,2,3でないもの
    notin_b = [0] # 0,...,3=dim(b)-1の中でaxes_a=1,2,3でないもの

    newshape_a = [B*W_O*H_O, C_I*K*K] # = [a_shapeの中でnotin_a(0, 4, 5)番目だけの積, a_shapeの中でaxes_a(1, 2, 3)番目だけの積])
    newshape_b = [C_I*K*K, C_O] # = [b_shapeの中でaxes_b(1, 2, 3)番目だけの積, b_shapeの中でnotin_b(0)番目だけの積])

まとめ

というわけで、

  • $N = B W_O H_O$
  • $M = C_O $
  • $L = C_I W_K H_K$

で、N行L列とL行M列の積になる、でした。

0
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
2