はじめに
諸事情によりInceptionV3の学習済みモデルが必要になりました。
TensorFlowとPyTorchは重みが提供されているのですが、私が普段使いしているChainerは軽く探して見つけることができませんでした。
InceptionV3をどうしてもChainerで実装したかったのですが、私はGPUを1024基持っていないので、ImageNetを15分で学習することを諦めてPyTorchの学習済みモデルの重みをChainerのモデルに転写することにしました。
onnx-chainer??聞いたことないですね。
PyTorchの学習済みモデルについて
PyTorchのInceptionV3学習済みモデルはtorchvisionに最初から備わっています。展開するとわかるのですが、PyTorchでは学習済みモデルをOrderedDictとして固めています。
>>> import torch.utils.model_zoo as model_zoo
>>> tmodel = model_zoo.load_url('https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth')
>>> type(tmodel)
<class 'collections.OrderedDict'>
>>>
中身はtorch.FloatTensorで、手前のlayerから順に格納されているようです。.numpy()とすれば中身をnumpy.ndarrayに変換できます。
>>> tmodel['fc.weight'].numpy()
Chainerモデルについて
一方、Chainerではモデルの重みをchainer.Linkのパラメータとして定義しています。例えば、convの重みは以下のように取得できます。
>>> conv.W.data
>>> conv.b.data
モデル定義の際にW, bに値を渡してやれば、任意の値でモデルを初期化できます。
やること
つまり、PyTorchの学習済みモデルを展開してnumpy.ndarrayに変換し、chainer.Linkの対応する変数に渡してやれば、学習済みchainer.Chainが手に入ります。
ChainerでInceptionV3を実装
重みを転写するにあたり、簡単のためにlayerの構造や名前をPyTorch実装と完全に揃えておきます。
import chainer
from chainer import functions as F
from chainer import links as L
class InceptionV3(chainer.Chain):
    """Inception v3 trunk whose link names mirror the PyTorch model.

    Only the convolutional trunk is defined here: the network ends with an
    8x8 average pooling followed by dropout, so the result is a pooled
    feature map rather than class scores.  The size comments in
    ``__call__`` read ``height x width x channels`` and assume a
    299 x 299 input.
    """

    def __init__(self):
        super(InceptionV3, self).__init__()
        with self.init_scope():
            self.Conv2d_1a_3x3 = BasicConv2d(3, 32, ksize=3, stride=2)
            self.Conv2d_2a_3x3 = BasicConv2d(32, 32, ksize=3)
            self.Conv2d_2b_3x3 = BasicConv2d(32, 64, ksize=3, pad=1)
            self.Conv2d_3b_1x1 = BasicConv2d(64, 80, ksize=1)
            self.Conv2d_4a_3x3 = BasicConv2d(80, 192, ksize=3)
            self.Mixed_5b = InceptionA(192, pool_features=32)
            self.Mixed_5c = InceptionA(256, pool_features=64)
            self.Mixed_5d = InceptionA(288, pool_features=64)
            self.Mixed_6a = InceptionB(288)
            self.Mixed_6b = InceptionC(768, channels_7x7=128)
            self.Mixed_6c = InceptionC(768, channels_7x7=160)
            self.Mixed_6d = InceptionC(768, channels_7x7=160)
            self.Mixed_6e = InceptionC(768, channels_7x7=192)
            self.Mixed_7a = InceptionD(768)
            self.Mixed_7b = InceptionE(1280)
            self.Mixed_7c = InceptionE(2048)

    def __call__(self, x):
        """Run the trunk on ``x`` and return the pooled, dropped-out features.

        Args:
            x: Input batch; presumably shaped (batch, 3, 299, 299) --
                TODO confirm against the caller.

        Returns:
            The output of ``F.dropout`` applied to the 8x8 average-pooled
            feature map.
        """
        h = x
        # 299 x 299 x 3
        h = self.Conv2d_1a_3x3(h)
        # 149 x 149 x 32
        h = self.Conv2d_2a_3x3(h)
        # 147 x 147 x 32
        h = self.Conv2d_2b_3x3(h)
        # 147 x 147 x 64
        # cover_all=False: Chainer's default (True) rounds the output size
        # up, whereas PyTorch rounds down.  Both agree for a 299 x 299
        # input but can differ by one pixel for other input sizes.
        h = F.max_pooling_2d(h, ksize=3, stride=2, cover_all=False)
        # 73 x 73 x 64
        h = self.Conv2d_3b_1x1(h)
        # 73 x 73 x 80
        h = self.Conv2d_4a_3x3(h)
        # 71 x 71 x 192
        h = F.max_pooling_2d(h, ksize=3, stride=2, cover_all=False)
        # 35 x 35 x 192
        h = self.Mixed_5b(h)
        # 35 x 35 x 256
        h = self.Mixed_5c(h)
        # 35 x 35 x 288
        h = self.Mixed_5d(h)
        # 35 x 35 x 288
        h = self.Mixed_6a(h)
        # 17 x 17 x 768
        h = self.Mixed_6b(h)
        # 17 x 17 x 768
        h = self.Mixed_6c(h)
        # 17 x 17 x 768
        h = self.Mixed_6d(h)
        # 17 x 17 x 768
        h = self.Mixed_6e(h)
        # 17 x 17 x 768
        h = self.Mixed_7a(h)
        # 8 x 8 x 1280
        h = self.Mixed_7b(h)
        # 8 x 8 x 2048
        h = self.Mixed_7c(h)
        # 8 x 8 x 2048
        h = F.average_pooling_2d(h, ksize=8)
        # 1 x 1 x 2048
        h = F.dropout(h)
        return h
class InceptionA(chainer.Chain):
    """Inception block with 1x1, 5x5, double-3x3 and pooled branches.

    The four branch outputs are concatenated along axis 1, giving
    ``64 + 64 + 96 + pool_features`` output channels.

    Args:
        in_channels: Number of channels of the incoming feature map.
        pool_features: Output channels of the 1x1 convolution applied
            after average pooling.
    """

    def __init__(self, in_channels, pool_features):
        super(InceptionA, self).__init__()
        with self.init_scope():
            # Single 1x1 branch.
            self.branch1x1 = BasicConv2d(in_channels, 64, ksize=1)

            # 1x1 reduction followed by a padded 5x5.
            self.branch5x5_1 = BasicConv2d(in_channels, 48, ksize=1)
            self.branch5x5_2 = BasicConv2d(48, 64, ksize=5, pad=2)

            # 1x1 reduction followed by two padded 3x3 convolutions.
            self.branch3x3dbl_1 = BasicConv2d(in_channels, 64, ksize=1)
            self.branch3x3dbl_2 = BasicConv2d(64, 96, ksize=3, pad=1)
            self.branch3x3dbl_3 = BasicConv2d(96, 96, ksize=3, pad=1)

            # 1x1 convolution applied to the average-pooled input.
            self.branch_pool = BasicConv2d(
                in_channels, pool_features, ksize=1)

    def __call__(self, x):
        """Apply all four branches to ``x`` and concatenate on axis 1."""
        out_1x1 = self.branch1x1(x)

        out_5x5 = self.branch5x5_2(self.branch5x5_1(x))

        out_3x3dbl = x
        for layer in (self.branch3x3dbl_1,
                      self.branch3x3dbl_2,
                      self.branch3x3dbl_3):
            out_3x3dbl = layer(out_3x3dbl)

        pooled = F.average_pooling_2d(x, ksize=3, stride=1, pad=1)
        out_pool = self.branch_pool(pooled)

        return F.concat((out_1x1, out_5x5, out_3x3dbl, out_pool), axis=1)
class InceptionB(chainer.Chain):
    """Downsampling Inception block (all branches use stride 2).

    A strided 3x3 convolution, a double-3x3 stack ending in a strided
    convolution, and a strided max pooling of the input are concatenated
    along axis 1, giving ``384 + 96 + in_channels`` output channels.

    Args:
        in_channels: Number of channels of the incoming feature map.
    """

    def __init__(self, in_channels):
        super(InceptionB, self).__init__()
        with self.init_scope():
            self.branch3x3 = BasicConv2d(in_channels, 384, ksize=3, stride=2)
            self.branch3x3dbl_1 = BasicConv2d(in_channels, 64, ksize=1)
            self.branch3x3dbl_2 = BasicConv2d(64, 96, ksize=3, pad=1)
            self.branch3x3dbl_3 = BasicConv2d(96, 96, ksize=3, stride=2)

    def __call__(self, x):
        """Apply the three branches to ``x`` and concatenate on axis 1."""
        branch3x3 = self.branch3x3(x)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        # cover_all=False keeps the pooled size equal to that of the strided
        # convolutions above (which round down) for every input size, and
        # matches PyTorch's floor behaviour.  Chainer's default of True
        # rounds up and can yield a map one pixel larger, breaking concat.
        branch_pool = F.max_pooling_2d(x, ksize=3, stride=2, cover_all=False)

        outputs = (branch3x3, branch3x3dbl, branch_pool)
        return F.concat(outputs, axis=1)
class InceptionC(chainer.Chain):
    """Inception block built from factorised 7x7 convolutions.

    Each of the four branches ends with 192 channels; they are
    concatenated along axis 1 for ``768`` output channels.

    Args:
        in_channels: Number of channels of the incoming feature map.
        channels_7x7: Width of the intermediate 1x7 / 7x1 convolutions.
    """

    def __init__(self, in_channels, channels_7x7):
        super(InceptionC, self).__init__()
        mid = channels_7x7
        # Keyword sets for the two halves of a factorised 7x7 convolution.
        row = dict(ksize=(1, 7), pad=(0, 3))
        col = dict(ksize=(7, 1), pad=(3, 0))
        with self.init_scope():
            self.branch1x1 = BasicConv2d(in_channels, 192, ksize=1)

            self.branch7x7_1 = BasicConv2d(in_channels, mid, ksize=1)
            self.branch7x7_2 = BasicConv2d(mid, mid, **row)
            self.branch7x7_3 = BasicConv2d(mid, 192, **col)

            self.branch7x7dbl_1 = BasicConv2d(in_channels, mid, ksize=1)
            self.branch7x7dbl_2 = BasicConv2d(mid, mid, **col)
            self.branch7x7dbl_3 = BasicConv2d(mid, mid, **row)
            self.branch7x7dbl_4 = BasicConv2d(mid, mid, **col)
            self.branch7x7dbl_5 = BasicConv2d(mid, 192, **row)

            self.branch_pool = BasicConv2d(in_channels, 192, ksize=1)

    def __call__(self, x):
        """Apply all four branches to ``x`` and concatenate on axis 1."""
        out_1x1 = self.branch1x1(x)

        out_7x7 = x
        for layer in (self.branch7x7_1,
                      self.branch7x7_2,
                      self.branch7x7_3):
            out_7x7 = layer(out_7x7)

        out_7x7dbl = x
        for layer in (self.branch7x7dbl_1,
                      self.branch7x7dbl_2,
                      self.branch7x7dbl_3,
                      self.branch7x7dbl_4,
                      self.branch7x7dbl_5):
            out_7x7dbl = layer(out_7x7dbl)

        pooled = F.average_pooling_2d(x, ksize=3, stride=1, pad=1)
        out_pool = self.branch_pool(pooled)

        return F.concat((out_1x1, out_7x7, out_7x7dbl, out_pool), axis=1)
class InceptionD(chainer.Chain):
    """Downsampling Inception block (all branches use stride 2).

    A 1x1 + strided 3x3 branch, a 1x1 + 1x7 + 7x1 + strided 3x3 branch and
    a strided max pooling of the input are concatenated along axis 1,
    giving ``320 + 192 + in_channels`` output channels.

    Args:
        in_channels: Number of channels of the incoming feature map.
    """

    def __init__(self, in_channels):
        super(InceptionD, self).__init__()
        with self.init_scope():
            self.branch3x3_1 = BasicConv2d(in_channels, 192, ksize=1)
            self.branch3x3_2 = BasicConv2d(192, 320, ksize=3, stride=2)
            self.branch7x7x3_1 = BasicConv2d(in_channels, 192, ksize=1)
            self.branch7x7x3_2 = BasicConv2d(192, 192, ksize=(1, 7),
                                             pad=(0, 3))
            self.branch7x7x3_3 = BasicConv2d(192, 192, ksize=(7, 1),
                                             pad=(3, 0))
            self.branch7x7x3_4 = BasicConv2d(192, 192, ksize=3, stride=2)

    def __call__(self, x):
        """Apply the three branches to ``x`` and concatenate on axis 1."""
        branch3x3 = self.branch3x3_1(x)
        branch3x3 = self.branch3x3_2(branch3x3)

        branch7x7x3 = self.branch7x7x3_1(x)
        branch7x7x3 = self.branch7x7x3_2(branch7x7x3)
        branch7x7x3 = self.branch7x7x3_3(branch7x7x3)
        branch7x7x3 = self.branch7x7x3_4(branch7x7x3)

        # cover_all=False keeps the pooled size equal to that of the strided
        # convolutions above (which round down) for every input size, and
        # matches PyTorch's floor behaviour.  Chainer's default of True
        # rounds up and can yield a map one pixel larger, breaking concat.
        branch_pool = F.max_pooling_2d(x, ksize=3, stride=2, cover_all=False)

        outputs = (branch3x3, branch7x7x3, branch_pool)
        return F.concat(outputs, axis=1)
class InceptionE(chainer.Chain):
    """Inception block whose 3x3 branches fan out into 1x3 and 3x1 halves.

    The branches contribute 320, 2 * 384, 2 * 384 and 192 channels and are
    concatenated along axis 1 for ``2048`` output channels.

    Args:
        in_channels: Number of channels of the incoming feature map.
    """

    def __init__(self, in_channels):
        super(InceptionE, self).__init__()
        # Keyword sets for the parallel 1x3 and 3x1 convolutions.
        row = dict(ksize=(1, 3), pad=(0, 1))
        col = dict(ksize=(3, 1), pad=(1, 0))
        with self.init_scope():
            self.branch1x1 = BasicConv2d(in_channels, 320, ksize=1)

            self.branch3x3_1 = BasicConv2d(in_channels, 384, ksize=1)
            self.branch3x3_2a = BasicConv2d(384, 384, **row)
            self.branch3x3_2b = BasicConv2d(384, 384, **col)

            self.branch3x3dbl_1 = BasicConv2d(in_channels, 448, ksize=1)
            self.branch3x3dbl_2 = BasicConv2d(448, 384, ksize=3, pad=1)
            self.branch3x3dbl_3a = BasicConv2d(384, 384, **row)
            self.branch3x3dbl_3b = BasicConv2d(384, 384, **col)

            self.branch_pool = BasicConv2d(in_channels, 192, ksize=1)

    def __call__(self, x):
        """Apply all four branches to ``x`` and concatenate on axis 1."""
        out_1x1 = self.branch1x1(x)

        # The 1x3 and 3x1 convolutions both read the same 1x1 output.
        stem_3x3 = self.branch3x3_1(x)
        half_a = self.branch3x3_2a(stem_3x3)
        half_b = self.branch3x3_2b(stem_3x3)
        out_3x3 = F.concat([half_a, half_b], axis=1)

        stem_dbl = self.branch3x3dbl_2(self.branch3x3dbl_1(x))
        dbl_a = self.branch3x3dbl_3a(stem_dbl)
        dbl_b = self.branch3x3dbl_3b(stem_dbl)
        out_3x3dbl = F.concat([dbl_a, dbl_b], axis=1)

        pooled = F.average_pooling_2d(x, ksize=3, stride=1, pad=1)
        out_pool = self.branch_pool(pooled)

        return F.concat((out_1x1, out_3x3, out_3x3dbl, out_pool), axis=1)
class BasicConv2d(chainer.Chain):
    """Bias-free convolution followed by batch normalization and ReLU.

    Args:
        in_channels: Input channels of the convolution.
        out_channels: Output channels of the convolution; also the size
            of the batch normalization.
        **kwargs: Forwarded unchanged to ``L.Convolution2D`` (the callers
            in this file pass ``ksize``, ``stride`` and ``pad``).
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
        with self.init_scope():
            self.conv = L.Convolution2D(
                in_channels, out_channels, nobias=True, **kwargs)
            self.bn = L.BatchNormalization(out_channels, eps=0.001)

    def __call__(self, x):
        """Return ``relu(bn(conv(x)))``."""
        return F.relu(self.bn(self.conv(x)))
ちなみに、側鎖（補助分類器のAuxLogits）は不要だったので定義していません。最終の全結合層（fc）も同様に省略しています。
転写する
後は1層ずつ値をコピーするだけです。以下のスクリプトを実行しました。
一度しか使わないはずのコードなので、べた書きです。
import chainer
from chainer import Variable
from chainermodel.inception_v3 import InceptionV3
import torch.utils.model_zoo as model_zoo
# Class names the copy helpers expect for each kind of link; they are
# compared against ``c.__class__.__name__`` as a sanity check before copying.
CONV = 'Convolution2D'
BN = 'BatchNormalization'
BASIC_CONV = 'BasicConv2d'
def copy_array(t, c):
    """Copy the values of array ``t`` into the destination ``c``.

    Args:
        t: Source ``numpy.ndarray``.
        c: Destination.  Either a plain ``numpy.ndarray`` (written in
            place) or an object exposing a settable ``data`` attribute,
            such as a Chainer parameter.

    Raises:
        ValueError: If the source and destination shapes differ.
    """
    import numpy

    # Explicit check instead of ``assert`` so it still runs under ``-O``.
    if t.shape != c.shape:
        raise ValueError(
            'shape mismatch: source {} vs destination {}'.format(
                t.shape, c.shape))

    if isinstance(c, numpy.ndarray):
        # Raw arrays (presumably the BatchNormalization running statistics,
        # which Chainer keeps as plain arrays -- verify) are filled in place
        # rather than by rebinding ``ndarray.data``.
        c[...] = t
    else:
        # Store a private copy so the destination does not alias the
        # buffer of the source array.
        c.data = t.copy()
def copy_conv(t, c, name):
    """Copy the convolution weight stored under ``name`` into link ``c``.

    Args:
        t: Mapping of PyTorch parameter names to tensors.
        c: Destination link; its class name must equal ``CONV``.
        name: Prefix of the PyTorch key; ``'.conv.weight'`` is appended.
    """
    link_kind = type(c).__name__
    assert link_kind == CONV
    weight = t[name + '.conv.weight']
    copy_array(weight.numpy(), c.W)
def copy_bn(t, c, name):
    """Copy batch-normalization parameters and statistics into link ``c``.

    Args:
        t: Mapping of PyTorch parameter names to tensors.
        c: Destination link; its class name must equal ``BN``.
        name: Prefix of the PyTorch keys; the ``'.bn.*'`` suffixes below
            are appended to it.
    """
    assert type(c).__name__ == BN
    # (PyTorch key suffix, Chainer destination) pairs, copied in order.
    transfers = (
        ('.bn.weight', c.gamma),
        ('.bn.bias', c.beta),
        ('.bn.running_mean', c.avg_mean),
        ('.bn.running_var', c.avg_var),
    )
    for suffix, destination in transfers:
        copy_array(t[name + suffix].numpy(), destination)
def copy_basic_conv(t, c, name):
    """Copy the conv and batch-norm weights of one ``BasicConv2d`` link.

    Args:
        t: Mapping of PyTorch parameter names to tensors.
        c: Destination link; its class name must equal ``BASIC_CONV``.
        name: Prefix of the PyTorch keys for this layer.
    """
    assert type(c).__name__ == BASIC_CONV
    for transfer, sublink in ((copy_conv, c.conv), (copy_bn, c.bn)):
        transfer(t, sublink, name)
def copy_mixed(t, c, name):
    """Copy the weights of every ``BasicConv2d`` child of chain ``c``.

    Args:
        t: Mapping of PyTorch parameter names to tensors.
        c: Destination chain whose direct children must all have the
            class name ``BASIC_CONV``.
        name: Prefix of the PyTorch keys for this block; each child's
            name is appended after a dot.
    """
    # Iterate through the public ``children()`` API instead of reaching
    # into the private ``_children`` container.
    for child in c.children():
        assert child.__class__.__name__ == BASIC_CONV
        copy_basic_conv(t, child, name + '.' + child.name)
def main():
    """Download the PyTorch weights, copy them layer by layer and save.

    The Chainer model is written to ``pretrained_inception_v3`` with
    ``chainer.serializers.save_npz``.
    """
    url = 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth'
    tmodel = model_zoo.load_url(url)
    cmodel = InceptionV3()

    # Plain conv + batch-norm layers at the start of the network.
    stem_layers = (
        'Conv2d_1a_3x3',
        'Conv2d_2a_3x3',
        'Conv2d_2b_3x3',
        'Conv2d_3b_1x1',
        'Conv2d_4a_3x3',
    )
    for layer_name in stem_layers:
        copy_basic_conv(tmodel, getattr(cmodel, layer_name), layer_name)

    # Inception blocks, each made of several conv + batch-norm children.
    mixed_blocks = (
        'Mixed_5b', 'Mixed_5c', 'Mixed_5d',
        'Mixed_6a', 'Mixed_6b', 'Mixed_6c', 'Mixed_6d', 'Mixed_6e',
        'Mixed_7a', 'Mixed_7b', 'Mixed_7c',
    )
    for block_name in mixed_blocks:
        copy_mixed(tmodel, getattr(cmodel, block_name), block_name)

    chainer.serializers.save_npz('pretrained_inception_v3', cmodel)
# Run the weight conversion only when this file is executed as a script.
if __name__ == '__main__':
    main()
幸いなことに、出力の誤差はほとんど気にならない程でした。
細かいことを気にせず、とりあえず動けばいいやって場合は使えると思います。
全てのモデルがそのまま転写できるわけではない
同じような方法をResNetで試したところ、convの途中でwidth、heightがずれたり出力に数%の誤差が生じることを確認しました。
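PyTorchとChainerではpadding周りの実装が異なるようで、そのあたりが影響している気がします。ちゃんと調べてないですが。（補足：Chainerのmax_pooling_2dは既定でcover_all=Trueのため、出力サイズの端数を切り上げます。PyTorchは既定で切り捨てるので、入力サイズによっては出力が1ピクセルずれます。cover_all=Falseを指定すれば揃う可能性があります。）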