Help us understand the problem. What is going on with this article?

Chainerで学ぶim2colを使った畳み込み演算と密行列積

More than 1 year has passed since last update.

今日も元気に深層学習の高速化をしている数値計算屋さんの皆さま、こんばんは。

最近の深層学習はほぼConvolutional NNですが、そのConvolution(畳み込み)の計算がim2colを経て行列積になるというのは常識になっていると思います。

では、実際にどんな形の畳み込みがどんな行列積になるのか?
よく問われて毎回算出方法忘れて調べているので、メモがてらまとめておきます。

なお、im2colの実装は一意ではない(行列の転置がありうる)ので、ここではChainerを参照します。

TL; DR

  • $N = B W_O H_O$
  • $M = C_O $
  • $L = C_I W_K H_K$

で、N行L列とL行M列の積になります。

ここで

  • $B$:ミニバッチサイズ
  • $W_O, H_O$:出力画像の幅・高
  • $C_O$:出力画像のチャネル数
  • $C_I$:入力画像のチャネル数
  • $W_K, H_K$:畳込みカーネルの幅・高

です。実際には$W_O, H_O$は入力画像とカーネルの幅・高およびストライド・パディングによって決まります。計算式は省略(Chainerのconvolutionの説明をご覧ください)。

詳細

im2colについて

im2colそのものについては既に他に優れた説明があるので省略します(例えば以下が参考になります)

ChainerのConvolutionをしてみよう

とりあえず動かしてみましょう

import chainer
import numpy

c_i = 3
w_i = 224
h_i = 224
c_o = 64
b_n = 16

conv = chainer.links.Convolution2D(c_i, c_o, ksize=7, stride=2, pad=3)
x = numpy.empty((b_n, c_i, w_i, h_i), dtype=numpy.float32)
y = conv(x)
print(y.shape) # (16, 64, 112, 112) = (b_n, c_o, w_o, h_o)

と言ったように、(b_n, c_i, w_i, h_i)の画像をConvolution2Dにかけると(b_n, c_o, w_o, h_o)が出てくるようです(それはそう)。

Chainerの中身を見てみよう(テンソル積への変換)

Chainerでは実際にはConvolutionはテンソル積(CPUの場合、NumPyのtensordotに落ちます)。その場所はconvolution_2d.pyにあります。

# https://github.com/chainer/chainer/blob/edc929818b0021543d2b8dc773aa89b41ec06d09/chainer/functions/connection/convolution_2d.py#L134-L138 より引用
        col = conv.im2col_cpu(
            x, kh, kw, self.sy, self.sx, self.ph, self.pw,
            cover_all=self.cover_all, dy=self.dy, dx=self.dx)
        y = numpy.tensordot(
            col, W, ((1, 2, 3), (1, 2, 3))).astype(x.dtype, copy=False)

このcol, W, yの形を形を見るとどんなテンソル積が実行されているのかが分かりますね。
先のコードを実行し、printなりデバッガなりで止めて中身を確認すると

  • col.shape= (16, 3, 7, 7, 112, 112) = (B, C_I, K, K, W_O, H_O)
  • W.shape= (64, 3, 7, 7) = (C_O, C_I, K, K)
  • y.shape= (16, 112, 112, 64) = (B, W_O, H_O, C_O)

でした。Chainerのim2colではテンソル積への変換にしているようですね。

NumPyの中身を見てみよう(行列積への変換)

NumPyのテンソル積のaxes引数、みんな混乱して難しいと思います。完璧に解説するのは趣旨から外れるので他をあたっていただくとして、今回は上記の呼び出しがどう普通の密行列積(いわゆるGEMM)に変換されるかだけに着目します。

コードは、numeric.pyにあります。

# https://github.com/numpy/numpy/blob/09cb2bdeb35faa79939f03cdcc8745f44d5116ca/numpy/core/numeric.py#L1030-L1091 より引用
    try:
        iter(axes)
    except Exception:
        axes_a = list(range(-axes, 0))
        axes_b = list(range(0, axes))
    else:
        axes_a, axes_b = axes
    try:
        na = len(axes_a)
        axes_a = list(axes_a)
    except TypeError:
        axes_a = [axes_a]
        na = 1
    try:
        nb = len(axes_b)
        axes_b = list(axes_b)
    except TypeError:
        axes_b = [axes_b]
        nb = 1

    a, b = asarray(a), asarray(b)
    as_ = a.shape
    nda = a.ndim
    bs = b.shape
    ndb = b.ndim
    equal = True
    if na != nb:
        equal = False
    else:
        for k in range(na):
            if as_[axes_a[k]] != bs[axes_b[k]]:
                equal = False
                break
            if axes_a[k] < 0:
                axes_a[k] += nda
            if axes_b[k] < 0:
                axes_b[k] += ndb
    if not equal:
        raise ValueError("shape-mismatch for sum")

    # Move the axes to sum over to the end of "a"
    # and to the front of "b"
    notin = [k for k in range(nda) if k not in axes_a]
    newaxes_a = notin + axes_a
    N2 = 1
    for axis in axes_a:
        N2 *= as_[axis]
    newshape_a = (int(multiply.reduce([as_[ax] for ax in notin])), N2)
    olda = [as_[axis] for axis in notin]

    notin = [k for k in range(ndb) if k not in axes_b]
    newaxes_b = axes_b + notin
    N2 = 1
    for axis in axes_b:
        N2 *= bs[axis]
    newshape_b = (N2, int(multiply.reduce([bs[ax] for ax in notin])))
    oldb = [bs[axis] for axis in notin]

    at = a.transpose(newaxes_a).reshape(newshape_a)
    bt = b.transpose(newaxes_b).reshape(newshape_b)
    res = dot(at, bt)
    return res.reshape(olda + oldb)

長い・・・ですが、とりあえず最後のatとbtが実際の行列積の形っぽいです。
ということで、newshape_aとnewshape_bが分かれば良さそうですね。

で、今回、以下のことが分かっています

  • 入力a, bにはcol, Wが入っており、a.shape=(B, C_I, K, K, W_O, H_O), b.shape=(C_O, C_I, K, K)
  • axesには((1, 2, 3), (1, 2, 3))

ということで、これらの情報を固定し展開し、また、oldaとoldbはatとbtから明らかに分かるのでこの計算を除外すると、以下のようになります

# https://github.com/numpy/numpy/blob/09cb2bdeb35faa79939f03cdcc8745f44d5116ca/numpy/core/numeric.py#L1030-L1091 を改変・大幅に省略
    a_shape = (B, C_I, K, K, W_O, H_O)
    b_shape = (C_O, C_I, K, K)
    axes_a = [1, 2, 3]
    axes_b = [1, 2, 3]
    notin_a = [0, 4, 5] # 0,...,5=dim(a)-1の中でaxes_a=1,2,3でないもの
    notin_b = [0] # 0,...,3=dim(b)-1の中でaxes_a=1,2,3でないもの

    newshape_a = [B*W_O*H_O, C_I*K*K] # = [a_shapeの中でnotin_a(0, 4, 5)番目だけの積, a_shapeの中でaxes_a(1, 2, 3)番目だけの積])
    newshape_b = [C_I*K*K, C_O] # = [b_shapeの中でaxes_b(1, 2, 3)番目だけの積, b_shapeの中でnotin_b(0)番目だけの積])

まとめ

というわけで、

  • $N = B W_O H_O$
  • $M = C_O $
  • $L = C_I W_K H_K$

で、N行L列とL行M列の積になる、でした。

aokomoriuta
土木工学の水理学・計算力学な人で、主にソフトウェアの高速化をやっています。
https://aokomoriuta.bitbucket.io/
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
No comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
ユーザーは見つかりませんでした