LoginSignup
13
17

pytorchとchainerが似ている件。CNNの書き方で比べてみよう

Last updated at Posted at 2018-03-11

以前からずっとchainerを使っていたが、最近pytorchを試してみました。

この2つは驚くほど似ていると思うので、ここでコードを並べて比較してみようと思います。

実装するモデルの概念

  • クラスの中でニューラルネットワークの構造を定義する
  • 分類問題に使うためのCNN
  • 畳み込み層2つ、線形層2つ
  • ミニバッチで学習する
  • 学習するためのメソッドはクラスの中に定義する
  • 経過した時間を常に計る
  • 毎回検証データに対する正確度を計算して出力する

pytorch

import torch
import time

relu = torch.nn.ReLU()
entropy = torch.nn.CrossEntropyLoss()
pool = torch.nn.MaxPool2d(2)

class PytorchCnn(torch.nn.Module):
    def __init__(self,gakushuuritsu):
        super(PytorchCnn,self).__init__()
        self.c1 = torch.nn.Conv2d(1,16,5,1,2)
        self.c2 = torch.nn.Conv2d(16,32,5,1,2)
        self.l1 = torch.nn.Linear(1568,16)
        self.l2 = torch.nn.Linear(16,10)
        self.opt = torch.optim.Adam(self.parameters(),lr=gakushuuritsu)
    
    def forward(self,x):
        x = self.c1(x)
        x = relu(x)
        x = pool(x)
        x = self.c2(x)
        x = relu(x)
        x = pool(x)
        x = x.view(x.size()[0],-1)
        x = self.l1(x)
        x = relu(x)
        x = self.l2(x)
        return x

    def gakushuu(self,X_kunren,y_kunren,X_kenshou,y_kenshou,kurikaeshi,n_batch):
        n_kunren = len(y_kunren)
        X_kunren = torch.FloatTensor(X_kunren.reshape(-1,1,28,28))
        y_kunren = torch.LongTensor(y_kunren)
        X_kenshou = torch.FloatTensor(X_kenshou.reshape(-1,1,28,28))
        y_kenshou = torch.LongTensor(y_kenshou)
        
        print(u'==pytorch学習開始==')
        t_kaishi = time.time()
        for j in range(kurikaeshi):
            batch = torch.randperm(n_kunren)
            for i in range(0,n_kunren,n_batch):
                Xn = X_kunren[batch[i:i+n_batch]]
                yn = y_kunren[batch[i:i+n_batch]]
                loss = entropy(self(Xn),yn)
                self.opt.zero_grad()
                loss.backward()
                self.opt.step()
            seikaku = (self(X_kenshou).argmax(1)==y_kenshou).type(torch.FloatTensor).mean().item()
            print(u'%d回目 正確度%.4f もう%.3f分かかった'%(j+1,seikaku,(time.time()-t_kaishi)/60))

chainer

import chainer
import numpy as np

class ChainerCnn(chainer.Chain):
    def __init__(self,gakushuuritsu):
        super(ChainerCnn,self).__init__()
        with self.init_scope():
            self.c1 = chainer.links.Convolution2D(1,16,5,1,2)
            self.c2 = chainer.links.Convolution2D(16,32,5,1,2)
            self.l1 = chainer.links.Linear(1568,16)
            self.l2 = chainer.links.Linear(16,10)
        self.opt = chainer.optimizers.Adam(gakushuuritsu)
        self.opt.setup(self)
    
    def __call__(self,x):
        x = self.c1(x)
        x = chainer.functions.relu(x)
        x = chainer.functions.max_pooling_2d(x,2)
        x = self.c2(x)
        x = chainer.functions.relu(x)
        x = chainer.functions.max_pooling_2d(x,2)
        x = self.l1(x)
        x = chainer.functions.relu(x)
        x = self.l2(x)
        return x
    
    def gakushuu(self,X_kunren,y_kunren,X_kenshou,y_kenshou,kurikaeshi,n_batch):
        n_kunren = len(y_kunren)
        X_kunren = X_kunren.reshape(-1,1,28,28).astype(np.float32)
        y_kunren = y_kunren.astype(np.int32)
        X_kenshou = X_kenshou.reshape(-1,1,28,28).astype(np.float32)
        y_kenshou = y_kenshou.astype(np.int32)
        
        print(u'==chainer学習開始==')
        t_kaishi = time.time()
        for j in range(kurikaeshi):
            batch = np.random.permutation(n_kunren)
            for i in range(0,n_kunren,n_batch):
                Xn = X_kunren[batch[i:i+n_batch]]
                yn = y_kunren[batch[i:i+n_batch]]
                loss = chainer.functions.softmax_cross_entropy(self(Xn),yn)
                self.cleargrads()
                loss.backward()
                self.opt.update()
            seikaku = (chainer.functions.argmax(self(X_kenshou),1).data==y_kenshou).mean()
            print(u'%d回目 正確度%.4f もう%.3f分かかった'%(j+1,seikaku,(time.time()-t_kaishi)/60))

MNIST数字に使う

from sklearn import datasets
from sklearn.model_selection import train_test_split

gakushuuritsu = 0.001
kurikaeshi = 10
n_batch = 64

# データの準備
mnist = datasets.fetch_mldata('MNIST original')
X_kunren,X_kenshou,y_kunren,y_kenshou = train_test_split(mnist.data/255.,mnist.target)

# pytorch
ptmodel = PytorchCnn(gakushuuritsu)
ptmodel.gakushuu(X_kunren,y_kunren,X_kenshou,y_kenshou,kurikaeshi,n_batch)

# chainer
chnmodel = ChainerCnn(gakushuuritsu)
chnmodel.gakushuu(X_kunren,y_kunren,X_kenshou,y_kenshou,kurikaeshi,n_batch)

結果

==pytorch学習開始==
1回目 正確度0.8943 もう0.503分かかった
2回目 正確度0.9437 もう0.956分かかった
3回目 正確度0.9519 もう1.356分かかった
4回目 正確度0.9581 もう1.842分かかった
5回目 正確度0.9629 もう2.311分かかった
6回目 正確度0.9686 もう2.812分かかった
7回目 正確度0.9716 もう3.266分かかった
8回目 正確度0.9719 もう3.697分かかった
9回目 正確度0.9694 もう4.163分かかった
10回目 正確度0.9705 もう4.625分かかった
==chainer学習開始==
1回目 正確度0.9293 もう1.366分かかった
2回目 正確度0.9534 もう2.491分かかった
3回目 正確度0.9633 もう3.688分かかった
4回目 正確度0.9665 もう4.771分かかった
5回目 正確度0.9654 もう6.009分かかった
6回目 正確度0.9720 もう7.294分かかった
7回目 正確度0.9658 もう8.566分かかった
8回目 正確度0.9693 もう10.206分かかった
9回目 正確度0.9730 もう11.475分かかった
10回目 正確度0.9765 もう12.617分かかった

主な違うところ

  • pytorchではforwardを定義するものは自動的に呼び出す時に使われる。chainerでは直接__call__を定義する。
  • pytorchで使うためにArrayからTensorへ、そしてTensorからVariableへ、2度変換しなければならないので、ちょっと面倒。chainerではそのままArrayを使ってもいい。自動的にVariableに変換される。ただし浮動小数点数はfloat32に変換する必要がある。 [2018/8/19] pytorchはv0.4に更新された後、もうVariableへ変換する必要がなくなって、直接Tensorを使うことができるので便利になった。
  • pytorchではargmaxみたいなものがなく、代わりにtorch.maxを使う時にmaxとargmaxは一緒に戻される。
  • pytorchではConvolution2DからLinearへ向かう時、xを変形する段階を自分で書かなければならないが、chainerでは自動的に変形される。

速度についてですが、明らかに違って、pytorchの方が2~3倍ほど速い。データが大きいほど差が著しくなるようです。

似ているように見えるのに速度はそんなに違うなんて意外です。

pytorchでもっと書きやすく

それに、実はpytorchはもっと便利な書き方があります。それはtorch.nn.Sequentialを使うことです。こんな風に書き直せます。

flat = torch.nn.Module()
flat.forward = lambda x:x.view(x.size()[0],-1)

class PytorchCnn(torch.nn.Sequential):
    def __init__(self,gakushuuritsu):
        super(PytorchCnn,self).__init__(
            torch.nn.Conv2d(1,16,5,1,2),
            relu,
            pool,
            torch.nn.Conv2d(16,32,5,1,2),
            relu,
            pool,
            flat,
            torch.nn.Linear(1568,16),
            relu,
            torch.nn.Linear(16,10))
        self.opt = torch.optim.Adam(self.parameters(),lr=gakushuuritsu)

これを使ったらforwardを定義する必要なく、自動的に順番通り進むので便利。

ただし、変形の段階は前もって変形するためのクラスを定義しておかなければならない。

torch.nn.Sequentialに入れるものはtorch.nn.Moduleクラスかサブクラスのオブジェクトでなければならない。

終わりに

pytorchとchainerはすごく似ています。chainerを使っていたかたはpytorchに乗り換えやすいと思います。

その他にpytorchとchainerを比較する記事はこちらも参考に

13
17
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
13
17