LoginSignup
11
2

More than 5 years have passed since last update.

Chainer版InceptionV3の学習済みモデルを作りました

Last updated at Posted at 2017-12-18

はじめに

諸事情によりInceptionV3の学習済みモデルが必要になりました。
TensorFlowとPyTorchでは学習済みの重みが提供されているのですが、私が普段使いしているChainer向けのものは、軽く探した限りでは見つけられませんでした。
InceptionV3をどうしてもChainerで実装したかったのですが、私はGPUを1024基持っていないので、ImageNetを15分で学習することを諦めてPyTorchの学習済みモデルの重みをChainerのモデルに転写することにしました。
onnx-chainer??聞いたことないですね。

PyTorchの学習済みモデルについて

PyTorchのInceptionV3学習済みモデルはtorchvisionに最初から備わっています。展開するとわかるのですが、PyTorchでは学習済みモデルをOrderedDictとして固めています。

>>> import torch.utils.model_zoo as model_zoo
>>> tmodel = model_zoo.load_url('https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth')
>>> type(tmodel)
<class 'collections.OrderedDict'>
>>>                                                                                                    

中身はtorch.FloatTensorで、手前のlayerから順に格納されているようです。.numpy()とすれば中身をnumpy.ndarrayに変換できます。

>>> tmodel['fc.weight'].numpy()

Chainerモデルについて

一方、Chainerではモデルの重みをchainer.Linkのパラメータとして定義しています。例えば、convの重みは以下のように取得できます。

>>> conv.W.data
>>> conv.b.data

モデル定義の際にinitialW、initial_bias引数で値を渡すか、定義後にW.data、b.dataへ値を書き込めば、任意の値でモデルを初期化できます。

やること

つまり、PyTorchの学習済みモデルを展開してnumpy.ndarrayに変換し、chainer.Linkの対応する変数に渡してやれば、学習済みchainer.Chainが手に入ります。

ChainerでInceptionV3を実装

重みを転写するにあたり、簡単のためにlayerの構造や名前をPyTorch実装と完全に揃えておきます。

inception_v3.py
import chainer
from chainer import functions as F
from chainer import links as L


class InceptionV3(chainer.Chain):
    """InceptionV3 feature extractor whose layer names mirror torchvision.

    Attribute names match the keys of the PyTorch state dict (for example
    ``Mixed_5b.branch1x1.conv.weight``), so pretrained weights can be copied
    by name. No auxiliary classifier and no final fully connected layer are
    defined: ``__call__`` returns pooled features, not class scores.

    Input is a batch laid out (N, C, H, W); the Inception blocks concatenate
    branches on axis 1. The shape comments in ``__call__`` are written as
    H x W x C for a 299 x 299 x 3 input, for which the output is
    (N, 2048, 1, 1).

    NOTE(review): torchvision applies ``transform_input`` in front of these
    pretrained weights; this model performs no input normalization, so
    confirm the expected preprocessing before running inference.
    """

    def __init__(self):
        super(InceptionV3, self).__init__()
        with self.init_scope():
            self.Conv2d_1a_3x3 = BasicConv2d(3, 32, ksize=3, stride=2)
            self.Conv2d_2a_3x3 = BasicConv2d(32, 32, ksize=3)
            self.Conv2d_2b_3x3 = BasicConv2d(32, 64, ksize=3, pad=1)
            self.Conv2d_3b_1x1 = BasicConv2d(64, 80, ksize=1)
            self.Conv2d_4a_3x3 = BasicConv2d(80, 192, ksize=3)
            self.Mixed_5b = InceptionA(192, pool_features=32)
            self.Mixed_5c = InceptionA(256, pool_features=64)
            self.Mixed_5d = InceptionA(288, pool_features=64)
            self.Mixed_6a = InceptionB(288)
            self.Mixed_6b = InceptionC(768, channels_7x7=128)
            self.Mixed_6c = InceptionC(768, channels_7x7=160)
            self.Mixed_6d = InceptionC(768, channels_7x7=160)
            self.Mixed_6e = InceptionC(768, channels_7x7=192)
            self.Mixed_7a = InceptionD(768)
            self.Mixed_7b = InceptionE(1280)
            self.Mixed_7c = InceptionE(2048)

    def __call__(self, x):
        h = x
        # 299 x 299 x 3
        h = self.Conv2d_1a_3x3(h)
        # 149 x 149 x 32
        h = self.Conv2d_2a_3x3(h)
        # 147 x 147 x 32
        h = self.Conv2d_2b_3x3(h)
        # 147 x 147 x 64
        # cover_all=False: floor-mode output size, matching PyTorch's
        # max_pool2d default (ceil_mode=False). Chainer's default
        # (cover_all=True) rounds up and diverges for some input sizes.
        h = F.max_pooling_2d(h, ksize=3, stride=2, cover_all=False)
        # 73 x 73 x 64
        h = self.Conv2d_3b_1x1(h)
        # 73 x 73 x 80
        h = self.Conv2d_4a_3x3(h)
        # 71 x 71 x 192
        h = F.max_pooling_2d(h, ksize=3, stride=2, cover_all=False)
        # 35 x 35 x 192
        h = self.Mixed_5b(h)
        # 35 x 35 x 256
        h = self.Mixed_5c(h)
        # 35 x 35 x 288
        h = self.Mixed_5d(h)
        # 35 x 35 x 288
        h = self.Mixed_6a(h)
        # 17 x 17 x 768
        h = self.Mixed_6b(h)
        # 17 x 17 x 768
        h = self.Mixed_6c(h)
        # 17 x 17 x 768
        h = self.Mixed_6d(h)
        # 17 x 17 x 768
        h = self.Mixed_6e(h)
        # 17 x 17 x 768
        h = self.Mixed_7a(h)
        # 8 x 8 x 1280
        h = self.Mixed_7b(h)
        # 8 x 8 x 2048
        h = self.Mixed_7c(h)
        # 8 x 8 x 2048
        h = F.average_pooling_2d(h, ksize=8)
        # 1 x 1 x 2048
        # F.dropout (default ratio 0.5) is only active when
        # chainer.config.train is True.
        h = F.dropout(h)
        return h


class InceptionA(chainer.Chain):
    """Inception module used at the 35 x 35 stage (Mixed_5b/5c/5d).

    Four parallel branches are concatenated on the channel axis:
    1x1 (64 ch); 1x1 -> 5x5 (64 ch); 1x1 -> 3x3 -> 3x3 (96 ch);
    3x3 average pool -> 1x1 (``pool_features`` ch). Every branch keeps the
    spatial size, so the output has 224 + pool_features channels.
    """

    def __init__(self, in_channels, pool_features):
        super(InceptionA, self).__init__()
        with self.init_scope():
            # Plain 1x1 projection.
            self.branch1x1 = BasicConv2d(in_channels, 64, ksize=1)
            # 1x1 reduction, then a padded 5x5.
            self.branch5x5_1 = BasicConv2d(in_channels, 48, ksize=1)
            self.branch5x5_2 = BasicConv2d(48, 64, ksize=5, pad=2)
            # 1x1 reduction, then two padded 3x3 convolutions.
            self.branch3x3dbl_1 = BasicConv2d(in_channels, 64, ksize=1)
            self.branch3x3dbl_2 = BasicConv2d(64, 96, ksize=3, pad=1)
            self.branch3x3dbl_3 = BasicConv2d(96, 96, ksize=3, pad=1)
            # 1x1 projection applied after average pooling.
            self.branch_pool = BasicConv2d(
                in_channels, pool_features, ksize=1)

    def __call__(self, x):
        # Tuple items are evaluated left to right, i.e. branch by branch in
        # the same order as they appear in the concatenated output.
        return F.concat((
            self.branch1x1(x),
            self.branch5x5_2(self.branch5x5_1(x)),
            self.branch3x3dbl_3(
                self.branch3x3dbl_2(self.branch3x3dbl_1(x))),
            self.branch_pool(
                F.average_pooling_2d(x, ksize=3, stride=1, pad=1)),
        ), axis=1)


class InceptionB(chainer.Chain):
    """Grid-reduction module (Mixed_6a): 35 x 35 -> 17 x 17 for a 299 input.

    Branches: stride-2 3x3 (384 ch); 1x1 -> 3x3 -> stride-2 3x3 (96 ch);
    stride-2 3x3 max pool (in_channels ch). The output has
    384 + 96 + in_channels channels.
    """

    def __init__(self, in_channels):
        super(InceptionB, self).__init__()
        with self.init_scope():
            self.branch3x3 = BasicConv2d(in_channels, 384, ksize=3, stride=2)

            self.branch3x3dbl_1 = BasicConv2d(in_channels, 64, ksize=1)
            self.branch3x3dbl_2 = BasicConv2d(64, 96, ksize=3, pad=1)
            self.branch3x3dbl_3 = BasicConv2d(96, 96, ksize=3, stride=2)

    def __call__(self, x):
        branch3x3 = self.branch3x3(x)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        # cover_all=False makes the pool use floor-mode output sizes, like the
        # stride-2 convolutions above and PyTorch's max_pool2d default. With
        # Chainer's default (cover_all=True), an even-sized input gives a pool
        # branch one pixel larger than the conv branches and F.concat fails.
        # For odd sizes (e.g. 35 x 35) both settings agree.
        branch_pool = F.max_pooling_2d(x, ksize=3, stride=2, cover_all=False)

        outputs = (branch3x3, branch3x3dbl, branch_pool)
        return F.concat(outputs, axis=1)


class InceptionC(chainer.Chain):
    """Inception module with factorized 7x7 convolutions (Mixed_6b-6e).

    Branches: 1x1; 1x1 -> 1x7 -> 7x1; 1x1 -> 7x1 -> 1x7 -> 7x1 -> 1x7;
    3x3 average pool -> 1x1. Each branch ends at 192 channels, so the output
    has 768 channels with unchanged spatial size.

    Args:
        in_channels: number of input channels.
        channels_7x7: width of the intermediate factorized-7x7 layers.
    """

    def __init__(self, in_channels, channels_7x7):
        super(InceptionC, self).__init__()
        c7 = channels_7x7
        # Keyword sets for the two halves of a factorized 7x7 convolution;
        # the padding keeps the spatial size unchanged.
        one_by_seven = dict(ksize=(1, 7), pad=(0, 3))
        seven_by_one = dict(ksize=(7, 1), pad=(3, 0))
        with self.init_scope():
            self.branch1x1 = BasicConv2d(in_channels, 192, ksize=1)

            self.branch7x7_1 = BasicConv2d(in_channels, c7, ksize=1)
            self.branch7x7_2 = BasicConv2d(c7, c7, **one_by_seven)
            self.branch7x7_3 = BasicConv2d(c7, 192, **seven_by_one)

            self.branch7x7dbl_1 = BasicConv2d(in_channels, c7, ksize=1)
            self.branch7x7dbl_2 = BasicConv2d(c7, c7, **seven_by_one)
            self.branch7x7dbl_3 = BasicConv2d(c7, c7, **one_by_seven)
            self.branch7x7dbl_4 = BasicConv2d(c7, c7, **seven_by_one)
            self.branch7x7dbl_5 = BasicConv2d(c7, 192, **one_by_seven)

            self.branch_pool = BasicConv2d(in_channels, 192, ksize=1)

    def __call__(self, x):
        def run(h, *layers):
            # Apply the given links one after another.
            for layer in layers:
                h = layer(h)
            return h

        return F.concat((
            self.branch1x1(x),
            run(x, self.branch7x7_1, self.branch7x7_2, self.branch7x7_3),
            run(x, self.branch7x7dbl_1, self.branch7x7dbl_2,
                self.branch7x7dbl_3, self.branch7x7dbl_4,
                self.branch7x7dbl_5),
            self.branch_pool(
                F.average_pooling_2d(x, ksize=3, stride=1, pad=1)),
        ), axis=1)


class InceptionD(chainer.Chain):
    """Grid-reduction module (Mixed_7a): 17 x 17 -> 8 x 8 for a 299 input.

    Branches: 1x1 -> stride-2 3x3 (320 ch); 1x1 -> 1x7 -> 7x1 -> stride-2
    3x3 (192 ch); stride-2 3x3 max pool (in_channels ch). The output has
    320 + 192 + in_channels channels.
    """

    def __init__(self, in_channels):
        super(InceptionD, self).__init__()
        with self.init_scope():
            self.branch3x3_1 = BasicConv2d(in_channels, 192, ksize=1)
            self.branch3x3_2 = BasicConv2d(192, 320, ksize=3, stride=2)

            self.branch7x7x3_1 = BasicConv2d(in_channels, 192, ksize=1)
            self.branch7x7x3_2 = BasicConv2d(192, 192, ksize=(1, 7), pad=(0, 3))
            self.branch7x7x3_3 = BasicConv2d(192, 192, ksize=(7, 1), pad=(3, 0))
            self.branch7x7x3_4 = BasicConv2d(192, 192, ksize=3, stride=2)

    def __call__(self, x):
        branch3x3 = self.branch3x3_1(x)
        branch3x3 = self.branch3x3_2(branch3x3)

        branch7x7x3 = self.branch7x7x3_1(x)
        branch7x7x3 = self.branch7x7x3_2(branch7x7x3)
        branch7x7x3 = self.branch7x7x3_3(branch7x7x3)
        branch7x7x3 = self.branch7x7x3_4(branch7x7x3)

        # cover_all=False: floor-mode output size like the stride-2 convs
        # above and PyTorch's max_pool2d default. Chainer's default
        # (cover_all=True) makes this branch one pixel larger for even-sized
        # inputs, which breaks the concat below.
        branch_pool = F.max_pooling_2d(x, ksize=3, stride=2, cover_all=False)
        outputs = (branch3x3, branch7x7x3, branch_pool)
        return F.concat(outputs, axis=1)


class InceptionE(chainer.Chain):
    """Inception module with expanded 1x3 / 3x1 filter banks (Mixed_7b/7c).

    Branches: 1x1 (320 ch); 1x1 -> [1x3 | 3x1] (384 + 384 ch);
    1x1 -> 3x3 -> [1x3 | 3x1] (384 + 384 ch); 3x3 average pool -> 1x1
    (192 ch). The output has 2048 channels with unchanged spatial size.
    """

    def __init__(self, in_channels):
        super(InceptionE, self).__init__()
        # Keyword sets for the asymmetric 3-tap convolutions; the padding
        # keeps the spatial size unchanged.
        one_by_three = dict(ksize=(1, 3), pad=(0, 1))
        three_by_one = dict(ksize=(3, 1), pad=(1, 0))
        with self.init_scope():
            self.branch1x1 = BasicConv2d(in_channels, 320, ksize=1)

            self.branch3x3_1 = BasicConv2d(in_channels, 384, ksize=1)
            self.branch3x3_2a = BasicConv2d(384, 384, **one_by_three)
            self.branch3x3_2b = BasicConv2d(384, 384, **three_by_one)

            self.branch3x3dbl_1 = BasicConv2d(in_channels, 448, ksize=1)
            self.branch3x3dbl_2 = BasicConv2d(448, 384, ksize=3, pad=1)
            self.branch3x3dbl_3a = BasicConv2d(384, 384, **one_by_three)
            self.branch3x3dbl_3b = BasicConv2d(384, 384, **three_by_one)

            self.branch_pool = BasicConv2d(in_channels, 192, ksize=1)

    def __call__(self, x):
        out_1x1 = self.branch1x1(x)

        stem = self.branch3x3_1(x)
        out_3x3 = F.concat(
            (self.branch3x3_2a(stem), self.branch3x3_2b(stem)), axis=1)

        stem_dbl = self.branch3x3dbl_2(self.branch3x3dbl_1(x))
        out_3x3dbl = F.concat(
            (self.branch3x3dbl_3a(stem_dbl), self.branch3x3dbl_3b(stem_dbl)),
            axis=1)

        pooled = F.average_pooling_2d(x, ksize=3, stride=1, pad=1)
        out_pool = self.branch_pool(pooled)

        return F.concat((out_1x1, out_3x3, out_3x3dbl, out_pool), axis=1)


class BasicConv2d(chainer.Chain):
    """Bias-free convolution -> batch normalization -> ReLU.

    Extra keyword arguments (ksize, stride, pad, ...) are forwarded
    unchanged to ``L.Convolution2D``. The convolution has no bias because
    the following batch normalization supplies the shift.

    NOTE(review): eps=0.001 is presumably chosen to match the BN epsilon of
    the PyTorch reference implementation; confirm before changing it.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
        with self.init_scope():
            self.conv = L.Convolution2D(
                in_channels, out_channels, nobias=True, **kwargs)
            self.bn = L.BatchNormalization(out_channels, eps=0.001)

    def __call__(self, x):
        return F.relu(self.bn(self.conv(x)))

ちなみに、側鎖（補助分類器AuxLogits）と最後の全結合層（fc）は不要だったので定義していません。そのため、このモデルの出力はクラスのスコアではなく2048次元の特徴量です。

転写する

後は1層ずつ値をコピーするだけです。以下のスクリプトを実行しました。
一度しか使わないはずのコードなので、べた書きです。

copy_inception.py
import chainer
from chainer import Variable
from chainermodel.inception_v3 import InceptionV3

import torch.utils.model_zoo as model_zoo

# Expected class names of the Chainer links; checked before copying weights
# into them so a mismatched model structure is caught early.
CONV, BN, BASIC_CONV = 'Convolution2D', 'BatchNormalization', 'BasicConv2d'

def copy_array(t, c):
    """Copy the values of array *t* into the Chainer target *c* in place.

    Args:
        t: source numpy array (e.g. ``tensor.numpy()`` from PyTorch).
        c: either a Chainer parameter/variable (values live in ``c.data``)
            or a plain numpy array such as BatchNormalization's
            ``avg_mean`` / ``avg_var`` persistents.

    Raises:
        ValueError: if the shapes of *t* and *c* differ.
    """
    import numpy

    # Explicit check instead of assert: asserts are stripped under -O.
    if t.shape != c.shape:
        raise ValueError('shape mismatch: source {} vs target {}'.format(
            t.shape, c.shape))
    # Assigning to ndarray.data re-points the array's buffer (deprecated and
    # unsafe in NumPy), and assigning Variable.data aliases the torch
    # tensor's memory. Copy element-wise into the existing array instead,
    # which also keeps the target's dtype.
    target = c if isinstance(c, numpy.ndarray) else c.data
    target[...] = t

def copy_conv(t, c, name):
    """Copy the PyTorch weight ``<name>.conv.weight`` into Chainer link *c*.

    Only ``W`` is copied: BasicConv2d builds its Convolution2D with
    ``nobias=True``, so the Chainer link has no bias to fill.

    Args:
        t: PyTorch state dict mapping parameter names to tensors.
        c: Chainer link expected to be an ``L.Convolution2D``.
        name: dotted module prefix of the BasicConv2d in *t*.

    Raises:
        TypeError: if *c* is not a Convolution2D (explicit check instead of
            assert, which is stripped under -O).
    """
    actual = c.__class__.__name__
    if actual != CONV:
        raise TypeError('{}: expected {}, got {}'.format(name, CONV, actual))
    copy_array(t[name + '.conv.weight'].numpy(), c.W)

def copy_bn(t, c, name):
    """Copy the PyTorch ``<name>.bn.*`` entries into Chainer BatchNorm *c*.

    PyTorch ``weight``/``bias`` map to Chainer ``gamma``/``beta``, and
    ``running_mean``/``running_var`` map to ``avg_mean``/``avg_var``.

    Args:
        t: PyTorch state dict mapping parameter names to tensors.
        c: Chainer link expected to be an ``L.BatchNormalization``.
        name: dotted module prefix of the BasicConv2d in *t*.

    Raises:
        TypeError: if *c* is not a BatchNormalization (explicit check
            instead of assert, which is stripped under -O).
    """
    actual = c.__class__.__name__
    if actual != BN:
        raise TypeError('{}: expected {}, got {}'.format(name, BN, actual))
    prefix = name + '.bn.'
    pairs = (
        ('weight', c.gamma),
        ('bias', c.beta),
        ('running_mean', c.avg_mean),
        ('running_var', c.avg_var),
    )
    for torch_key, target in pairs:
        copy_array(t[prefix + torch_key].numpy(), target)

def copy_basic_conv(t, c, name):
    """Copy one BasicConv2d (conv weight plus batch norm) from state dict *t*.

    Args:
        t: PyTorch state dict; keys are looked up under ``<name>.conv.*``
            and ``<name>.bn.*``.
        c: Chainer BasicConv2d link, filled in place.
        name: dotted module path shared by both models, e.g.
            ``'Conv2d_1a_3x3'`` or ``'Mixed_5b.branch1x1'``.

    Raises:
        TypeError: if *c* is not a BasicConv2d (explicit check instead of
            assert, which is stripped under -O).
    """
    actual = c.__class__.__name__
    if actual != BASIC_CONV:
        raise TypeError(
            '{}: expected {}, got {}'.format(name, BASIC_CONV, actual))
    copy_conv(t, c.conv, name)
    copy_bn(t, c.bn, name)

def copy_mixed(t, c, name):
    """Copy every BasicConv2d child of Inception block *c* from state dict *t*.

    Each child link is filled from PyTorch keys prefixed
    ``<name>.<child name>``. Pooling is done with functions rather than
    links, so only the convolution children are visited.

    Raises:
        TypeError: if *c* has a child that is not a BasicConv2d.
    """
    # Public Chain.children() API (each child carries its attribute name in
    # .name) instead of the private _children container.
    for child in c.children():
        actual = child.__class__.__name__
        if actual != BASIC_CONV:
            raise TypeError('{}.{}: expected {}, got {}'.format(
                name, child.name, BASIC_CONV, actual))
        copy_basic_conv(t, child, name + '.' + child.name)

def main():
    """Download torchvision's InceptionV3 weights and save them for Chainer.

    The stem convolutions are copied one by one, then each Mixed_* block is
    copied child by child. Only layers defined on the Chainer InceptionV3
    are copied; other state-dict entries (such as ``fc.weight``) are left
    unused. The result is written with ``save_npz``.
    """
    url = ('https://download.pytorch.org/models/'
           'inception_v3_google-1a9a5a14.pth')
    state = model_zoo.load_url(url)
    model = InceptionV3()

    stem_layers = (
        'Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3',
        'Conv2d_3b_1x1', 'Conv2d_4a_3x3',
    )
    for layer_name in stem_layers:
        copy_basic_conv(state, getattr(model, layer_name), layer_name)

    mixed_blocks = (
        'Mixed_5b', 'Mixed_5c', 'Mixed_5d',
        'Mixed_6a', 'Mixed_6b', 'Mixed_6c', 'Mixed_6d', 'Mixed_6e',
        'Mixed_7a', 'Mixed_7b', 'Mixed_7c',
    )
    for block_name in mixed_blocks:
        copy_mixed(state, getattr(model, block_name), block_name)

    chainer.serializers.save_npz('pretrained_inception_v3', model)


if __name__ == '__main__':
    main()

幸いなことに、PyTorch版との出力の誤差はほとんど気にならない程度でした。
細かいことを気にせず、とりあえず動けばいいやって場合は使えると思います。

全てのモデルがそのまま転写できるわけではない

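同じような方法をResNetで試したところ、convの途中でwidth、heightがずれたり、出力に数%の誤差が生じたりすることを確認しました。
PyTorchとChainerではpadding周りの実装が異なるようで、そのあたりが影響している気がします（ちゃんとは調べていませんが）。
少なくともサイズのずれについては、Chainerのmax_pooling_2dがデフォルトでcover_all=True（出力サイズを切り上げ）なのに対し、PyTorchのmax_pool2dはデフォルトでceil_mode=False（切り捨て）であることが一因と考えられます。例えばResNetの最初のmax pooling（ksize=3, stride=2, pad=1）に112×112の入力を与えると、PyTorchでは56×56、Chainerのデフォルトでは57×57になります。cover_all=Falseを指定すると揃います。
また、BatchNormalizationのepsのデフォルト値がChainer（2e-5）とPyTorch（1e-5）で異なることも、誤差の原因になり得ます（本記事のInceptionV3ではeps=0.001を明示しています）。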

11
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
11
2