I made a pretrained InceptionV3 model for Chainer

Introduction

For various reasons, I found myself needing a pretrained InceptionV3 model.
Pretrained weights are available for TensorFlow and PyTorch, but for Chainer, which is what I use day to day, a quick search turned up nothing.
I really wanted InceptionV3 in Chainer, but since I don't own 1024 GPUs, I gave up on training ImageNet in 15 minutes and decided instead to transfer the weights of the PyTorch pretrained model into a Chainer model.
onnx-chainer? Never heard of it.

The PyTorch pretrained model

PyTorch's pretrained InceptionV3 model ships with torchvision out of the box. If you unpack it, you can see that PyTorch stores the pretrained weights (the model's state_dict) as an OrderedDict.

>>> import torch.utils.model_zoo as model_zoo
>>> tmodel = model_zoo.load_url('https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth')
>>> type(tmodel)
<class 'collections.OrderedDict'>
>>>                                                                                                    

The values are torch.FloatTensor objects, and they appear to be stored in order starting from the earliest layers. Calling .numpy() on an entry converts it to a numpy.ndarray.

>>> tmodel['fc.weight'].numpy()
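
The stored keys and tensor shapes can also be inspected directly. As a quick sketch, the key names follow PyTorch's module structure (e.g. Conv2d_1a_3x3.conv.weight, the same names the copy script below relies on):

>>> for name, tensor in list(tmodel.items())[:5]:
...     print(name, tensor.size())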

The Chainer model

Chainer, on the other hand, defines model weights as parameters of a chainer.Link. For example, the weights of a conv link can be retrieved like this:

>>> conv.W.data
>>> conv.b.data

If you pass values for W and b when defining the model, you can initialize it with arbitrary values.
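
As a minimal sketch (the array shapes here are purely illustrative), a Convolution2D link can be initialized from existing numpy arrays via initialW and initial_bias, or its parameters can be overwritten in place after construction:

import numpy as np
import chainer.links as L

# W has shape (out_channels, in_channels, kh, kw), b has shape (out_channels,)
W = np.random.randn(32, 3, 3, 3).astype(np.float32)
b = np.zeros(32, dtype=np.float32)

# initialize at construction time
conv = L.Convolution2D(3, 32, ksize=3, initialW=W, initial_bias=b)

# or overwrite the parameters of an existing link
conv.W.data[...] = W
conv.b.data[...] = b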

The plan

In other words, if we unpack the PyTorch pretrained model, convert each weight to a numpy.ndarray, and hand it to the corresponding variable of the matching chainer.Link, we end up with a pretrained chainer.Chain.

Implementing InceptionV3 in Chainer

To make the weight transfer simple, the layer structure and names are aligned exactly with the PyTorch implementation.

inception_v3.py
import chainer
from chainer import functions as F
from chainer import links as L


class InceptionV3(chainer.Chain):
    def __init__(self):
        super(InceptionV3, self).__init__()
        with self.init_scope():
            self.Conv2d_1a_3x3 = BasicConv2d(3, 32, ksize=3, stride=2)
            self.Conv2d_2a_3x3 = BasicConv2d(32, 32, ksize=3)
            self.Conv2d_2b_3x3 = BasicConv2d(32, 64, ksize=3, pad=1)
            self.Conv2d_3b_1x1 = BasicConv2d(64, 80, ksize=1)
            self.Conv2d_4a_3x3 = BasicConv2d(80, 192, ksize=3)
            self.Mixed_5b = InceptionA(192, pool_features=32)
            self.Mixed_5c = InceptionA(256, pool_features=64)
            self.Mixed_5d = InceptionA(288, pool_features=64)
            self.Mixed_6a = InceptionB(288)
            self.Mixed_6b = InceptionC(768, channels_7x7=128)
            self.Mixed_6c = InceptionC(768, channels_7x7=160)
            self.Mixed_6d = InceptionC(768, channels_7x7=160)
            self.Mixed_6e = InceptionC(768, channels_7x7=192)
            self.Mixed_7a = InceptionD(768)
            self.Mixed_7b = InceptionE(1280)
            self.Mixed_7c = InceptionE(2048)

    def __call__(self, x):
        h = x
        # 299 x 299 x 3
        h = self.Conv2d_1a_3x3(h)
        # 149 x 149 x 32
        h = self.Conv2d_2a_3x3(h)
        # 147 x 147 x 32
        h = self.Conv2d_2b_3x3(h)
        # 147 x 147 x 64
        h = F.max_pooling_2d(h, ksize=3, stride=2)
        # 73 x 73 x 64
        h = self.Conv2d_3b_1x1(h)
        # 73 x 73 x 80
        h = self.Conv2d_4a_3x3(h)
        # 71 x 71 x 192
        h = F.max_pooling_2d(h, ksize=3, stride=2)
        # 35 x 35 x 192
        h = self.Mixed_5b(h)
        # 35 x 35 x 256
        h = self.Mixed_5c(h)
        # 35 x 35 x 288
        h = self.Mixed_5d(h)
        # 35 x 35 x 288
        h = self.Mixed_6a(h)
        # 17 x 17 x 768
        h = self.Mixed_6b(h)
        # 17 x 17 x 768
        h = self.Mixed_6c(h)
        # 17 x 17 x 768
        h = self.Mixed_6d(h)
        # 17 x 17 x 768
        h = self.Mixed_6e(h)
        # 17 x 17 x 768
        h = self.Mixed_7a(h)
        # 8 x 8 x 1280
        h = self.Mixed_7b(h)
        # 8 x 8 x 2048
        h = self.Mixed_7c(h)
        # 8 x 8 x 2048
        h = F.average_pooling_2d(h, ksize=8)
        # 1 x 1 x 2048
        h = F.dropout(h)
        return h


class InceptionA(chainer.Chain):
    def __init__(self, in_channels, pool_features):
        super(InceptionA, self).__init__()
        with self.init_scope():
            self.branch1x1 = BasicConv2d(in_channels, 64, ksize=1)
            self.branch5x5_1 = BasicConv2d(in_channels, 48, ksize=1)
            self.branch5x5_2 = BasicConv2d(48, 64, ksize=5, pad=2)
            self.branch3x3dbl_1 = BasicConv2d(in_channels, 64, ksize=1)
            self.branch3x3dbl_2 = BasicConv2d(64, 96, ksize=3, pad=1)
            self.branch3x3dbl_3 = BasicConv2d(96, 96, ksize=3, pad=1)
            self.branch_pool = BasicConv2d(in_channels, pool_features, ksize=1)

    def __call__(self, x):
        branch1x1 = self.branch1x1(x)

        branch5x5 = self.branch5x5_1(x)
        branch5x5 = self.branch5x5_2(branch5x5)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        branch_pool = F.average_pooling_2d(x, ksize=3, stride=1, pad=1)
        branch_pool = self.branch_pool(branch_pool)

        outputs = (branch1x1, branch5x5, branch3x3dbl, branch_pool)
        return F.concat(outputs, axis=1)


class InceptionB(chainer.Chain):
    def __init__(self, in_channels):
        super(InceptionB, self).__init__()
        with self.init_scope():
            self.branch3x3 = BasicConv2d(in_channels, 384, ksize=3, stride=2)

            self.branch3x3dbl_1 = BasicConv2d(in_channels, 64, ksize=1)
            self.branch3x3dbl_2 = BasicConv2d(64, 96, ksize=3, pad=1)
            self.branch3x3dbl_3 = BasicConv2d(96, 96, ksize=3, stride=2)

    def __call__(self, x):
        branch3x3 = self.branch3x3(x)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        branch_pool = F.max_pooling_2d(x, ksize=3, stride=2)

        outputs = (branch3x3, branch3x3dbl, branch_pool)
        return F.concat(outputs, axis=1)


class InceptionC(chainer.Chain):
    def __init__(self, in_channels, channels_7x7):
        super(InceptionC, self).__init__()
        with self.init_scope():
            self.branch1x1 = BasicConv2d(in_channels, 192, ksize=1)

            c7 = channels_7x7
            self.branch7x7_1 = BasicConv2d(in_channels, c7, ksize=1)
            self.branch7x7_2 = BasicConv2d(c7, c7, ksize=(1, 7), pad=(0, 3))
            self.branch7x7_3 = BasicConv2d(c7, 192, ksize=(7, 1), pad=(3, 0))

            self.branch7x7dbl_1 = BasicConv2d(in_channels, c7, ksize=1)
            self.branch7x7dbl_2 = BasicConv2d(c7, c7, ksize=(7, 1),
                    pad=(3, 0))
            self.branch7x7dbl_3 = BasicConv2d(c7, c7, ksize=(1, 7),
                    pad=(0, 3))
            self.branch7x7dbl_4 = BasicConv2d(c7, c7, ksize=(7, 1),
                    pad=(3, 0))
            self.branch7x7dbl_5 = BasicConv2d(c7, 192, ksize=(1, 7),
                    pad=(0, 3))

            self.branch_pool = BasicConv2d(in_channels, 192, ksize=1)

    def __call__(self, x):
        branch1x1 = self.branch1x1(x)

        branch7x7 = self.branch7x7_1(x)
        branch7x7 = self.branch7x7_2(branch7x7)
        branch7x7 = self.branch7x7_3(branch7x7)

        branch7x7dbl = self.branch7x7dbl_1(x)
        branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)

        branch_pool = F.average_pooling_2d(x, ksize=3, stride=1, pad=1)
        branch_pool = self.branch_pool(branch_pool)

        outputs = (branch1x1, branch7x7, branch7x7dbl, branch_pool)
        return F.concat(outputs, axis=1)


class InceptionD(chainer.Chain):
    def __init__(self, in_channels):
        super(InceptionD, self).__init__()
        with self.init_scope():
            self.branch3x3_1 = BasicConv2d(in_channels, 192, ksize=1)
            self.branch3x3_2 = BasicConv2d(192, 320, ksize=3, stride=2)

            self.branch7x7x3_1 = BasicConv2d(in_channels, 192, ksize=1)
            self.branch7x7x3_2 = BasicConv2d(192, 192, ksize=(1, 7), pad=(0, 3))
            self.branch7x7x3_3 = BasicConv2d(192, 192, ksize=(7, 1), pad=(3, 0))
            self.branch7x7x3_4 = BasicConv2d(192, 192, ksize=3, stride=2)

    def __call__(self, x):
        branch3x3 = self.branch3x3_1(x)
        branch3x3 = self.branch3x3_2(branch3x3)

        branch7x7x3 = self.branch7x7x3_1(x)
        branch7x7x3 = self.branch7x7x3_2(branch7x7x3)
        branch7x7x3 = self.branch7x7x3_3(branch7x7x3)
        branch7x7x3 = self.branch7x7x3_4(branch7x7x3)

        branch_pool = F.max_pooling_2d(x, ksize=3, stride=2)
        outputs = (branch3x3, branch7x7x3, branch_pool)
        return F.concat(outputs, axis=1)


class InceptionE(chainer.Chain):
    def __init__(self, in_channels):
        super(InceptionE, self).__init__()
        with self.init_scope():
            self.branch1x1 = BasicConv2d(in_channels, 320, ksize=1)

            self.branch3x3_1 = BasicConv2d(in_channels, 384, ksize=1)
            self.branch3x3_2a = BasicConv2d(384, 384, ksize=(1, 3), pad=(0, 1))
            self.branch3x3_2b = BasicConv2d(384, 384, ksize=(3, 1), pad=(1, 0))

            self.branch3x3dbl_1 = BasicConv2d(in_channels, 448, ksize=1)
            self.branch3x3dbl_2 = BasicConv2d(448, 384, ksize=3, pad=1)
            self.branch3x3dbl_3a = BasicConv2d(384, 384, ksize=(1, 3), pad=(0, 1))
            self.branch3x3dbl_3b = BasicConv2d(384, 384, ksize=(3, 1), pad=(1, 0))

            self.branch_pool = BasicConv2d(in_channels, 192, ksize=1)

    def __call__(self, x):
        branch1x1 = self.branch1x1(x)

        branch3x3 = self.branch3x3_1(x)
        branch3x3 = [
                self.branch3x3_2a(branch3x3),
                self.branch3x3_2b(branch3x3),
                ]
        branch3x3 = F.concat(branch3x3, axis=1)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = [
                self.branch3x3dbl_3a(branch3x3dbl),
                self.branch3x3dbl_3b(branch3x3dbl),
                ]
        branch3x3dbl = F.concat(branch3x3dbl, axis=1)

        branch_pool = F.average_pooling_2d(x, ksize=3, stride=1, pad=1)
        branch_pool = self.branch_pool(branch_pool)

        outputs = (branch1x1, branch3x3, branch3x3dbl, branch_pool)
        return F.concat(outputs, axis=1)


class BasicConv2d(chainer.Chain):
    def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
        with self.init_scope():
            self.conv = L.Convolution2D(in_channels, out_channels, nobias=True,
                    **kwargs)
            self.bn = L.BatchNormalization(out_channels, eps=0.001)

    def __call__(self, x):
        h = self.conv(x)
        h = self.bn(h)
        return F.relu(h)

Incidentally, the side branch (the auxiliary classifier) wasn't needed for my purposes, so it isn't defined.
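
As a quick sanity check (a minimal sketch, assuming the file above is importable as inception_v3.py), feeding a dummy 299 x 299 image through the model should reproduce the shapes noted in the comments:

import numpy as np
import chainer
from inception_v3 import InceptionV3

model = InceptionV3()
x = np.zeros((1, 3, 299, 299), dtype=np.float32)
with chainer.using_config('train', False):  # disable dropout for inference
    y = model(x)
print(y.shape)  # expected: (1, 2048, 1, 1)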

Transferring the weights

All that remains is to copy the values layer by layer. I ran the script below.
Since this code should only ever run once, it's written as a flat, throwaway script.

copy_inception.py
import chainer
from chainer import Variable
from chainermodel.inception_v3 import InceptionV3

import torch.utils.model_zoo as model_zoo

CONV = 'Convolution2D'
BN   = 'BatchNormalization'
BASIC_CONV = 'BasicConv2d'

def copy_array(t, c):
    # c is either a chainer Parameter (conv.W, bn.gamma, ...) or a plain
    # numpy array persistent value (bn.avg_mean, bn.avg_var); copy into
    # the underlying array in both cases
    dst = c.data if isinstance(c, Variable) else c
    assert t.shape == dst.shape
    dst[...] = t

def copy_conv(t, c, name):
    assert c.__class__.__name__ == CONV
    copy_array(t[name + '.conv.weight'].numpy(), c.W)

def copy_bn(t, c, name):
    assert c.__class__.__name__ == BN
    copy_array(t[name + '.bn.weight'].numpy()      , c.gamma)
    copy_array(t[name + '.bn.bias'].numpy()        , c.beta)
    copy_array(t[name + '.bn.running_mean'].numpy(), c.avg_mean)
    copy_array(t[name + '.bn.running_var'].numpy() , c.avg_var)

def copy_basic_conv(t, c, name):
    assert c.__class__.__name__ == BASIC_CONV
    copy_conv(t, c.conv, name)
    copy_bn(t, c.bn, name)

def copy_mixed(t, c, name):
    for childname in c._children:
        child = c[childname]
        assert child.__class__.__name__ == BASIC_CONV
        copy_basic_conv(t, child, name + '.' + childname)

def main():
    tmodel = model_zoo.load_url('https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth')
    cmodel = InceptionV3()

    copy_basic_conv(tmodel, cmodel.Conv2d_1a_3x3, 'Conv2d_1a_3x3')
    copy_basic_conv(tmodel, cmodel.Conv2d_2a_3x3, 'Conv2d_2a_3x3')
    copy_basic_conv(tmodel, cmodel.Conv2d_2b_3x3, 'Conv2d_2b_3x3')
    copy_basic_conv(tmodel, cmodel.Conv2d_3b_1x1, 'Conv2d_3b_1x1')
    copy_basic_conv(tmodel, cmodel.Conv2d_4a_3x3, 'Conv2d_4a_3x3')

    copy_mixed(tmodel, cmodel.Mixed_5b, 'Mixed_5b')
    copy_mixed(tmodel, cmodel.Mixed_5c, 'Mixed_5c')
    copy_mixed(tmodel, cmodel.Mixed_5d, 'Mixed_5d')
    copy_mixed(tmodel, cmodel.Mixed_6a, 'Mixed_6a')
    copy_mixed(tmodel, cmodel.Mixed_6b, 'Mixed_6b')
    copy_mixed(tmodel, cmodel.Mixed_6c, 'Mixed_6c')
    copy_mixed(tmodel, cmodel.Mixed_6d, 'Mixed_6d')
    copy_mixed(tmodel, cmodel.Mixed_6e, 'Mixed_6e')
    copy_mixed(tmodel, cmodel.Mixed_7a, 'Mixed_7a')
    copy_mixed(tmodel, cmodel.Mixed_7b, 'Mixed_7b')
    copy_mixed(tmodel, cmodel.Mixed_7c, 'Mixed_7c')

    chainer.serializers.save_npz('pretrained_inception_v3', cmodel)


if __name__ == '__main__':
    main()

Fortunately, the difference in outputs was small enough to be negligible.
If you don't care about the fine details and just need something that works, I think this approach is usable.
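
To use the transferred weights later, they can be loaded back in the usual way (a sketch, using the same module path and file name as in copy_inception.py):

import chainer
from chainermodel.inception_v3 import InceptionV3

model = InceptionV3()
chainer.serializers.load_npz('pretrained_inception_v3', model)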

Not every model can be transferred as-is

When I tried the same approach with ResNet, the width and height drifted partway through the conv layers, and the outputs differed by a few percent.
PyTorch and Chainer seem to differ in how they handle padding and round output sizes, and I suspect that is what's behind it, though I haven't investigated properly.
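
One difference I am aware of (though I haven't verified that it is the actual culprit for ResNet): Chainer's max_pooling_2d defaults to cover_all=True, which rounds the output size up, whereas PyTorch's MaxPool2d rounds down. On even-sized feature maps the two give different spatial sizes:

import numpy as np
import chainer.functions as F

x = np.zeros((1, 1, 112, 112), dtype=np.float32)

# Chainer's default (cover_all=True) rounds up: 56 x 56
print(F.max_pooling_2d(x, ksize=3, stride=2).shape)
# cover_all=False matches PyTorch's floor-based MaxPool2d: 55 x 55
print(F.max_pooling_2d(x, ksize=3, stride=2, cover_all=False).shape)

In the InceptionV3 above, the feature maps happen to be odd-sized wherever pooling is applied (147, 71, 35, 17), so both conventions give the same sizes, which is presumably why the transfer worked there.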