Help us understand the problem. What is going on with this article?

pytorchがとっつきにくい? skorchから始めてみようよ

More than 1 year has passed since last update.

ここで、pytorchをscikit-learnライクに使うことができるようにしたラッパーであるskorchについて紹介します。

20171220123959.png

この記事で記述しているスクリプトは以下に置いております。
https://gist.github.com/AnchorBlues/b51836c96b90b25b35f209ce7ac8f522

TL;DR

  • skorch使うと、

    • skorchを用いて作成したmodelオブジェクトには、scikit-learnのようにfitやpredictなどのメソッドが一通り揃っており、まさにscikit-learn並の手軽さでモデルの学習・予測の処理を行える
    • scikit-learnのAPIの恩恵に預かって、pipeline構築したりGridSearchしたりできる。 例えばこんな感じ。
    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import StandardScaler
    from skorch import NeuralNetClassifier
    
    model = NeuralNetClassifier(
        Net,   # pytorchで実装したnn.Module継承クラス
        max_epochs=10,
        lr=0.1
    )
    
    pipe = Pipeline([
        ('scale', StandardScaler()),
        ('net', model),
    ])
    
    pipe.fit(X, y)
    y_proba = pipe.predict_proba(X)
    

はじめに

pytorchの(他のDeep Learningフレームワークと比較した際の)メリット・デメリットを挙げると、だいたい以下のような感じになるのではないでしょうか。

  • メリット
    • define by run。 学習処理内にbreakpointを置いたデバッグが可能。
    • modelもlayerもどっちも同じModuleクラスというエレガントな設計。 複雑なネットワークも割と簡単に記述できる
    • tensorオブジェクトにsqueeze、viewなどのメソッドが充実しているのでコードの可読性が高い
    • gpuのon/offが非常に楽
  • デメリット
    • web上の情報がやや少ない
    • 学習・予測処理を自分で一から書く必要がある。 dataloaderを作成するところを含めて、学習・予測処理を記述するのが面倒くさい

特にデメリットの2つ目が、今回着目するところになります。

例えば

  • モデル(model)
  • 学習データ(x_train, y_train)

が用意できて、いよいよモデルの学習に取り掛かろう、という際に、例えばscikit-learnやkerasだと

model.fit(x_train, y_train)

と1行で書けてしまうのに対して、pytorchの普通のやり方だと、

# datasetオブジェクト作成
x_train_tensor = torch.Tensor(x_train).to(device)
y_train_tensor = torch.Tensor(y_train).to(device)
dataset = torch.utils.data.TensorDataset(x_train_tensor, y_train_tensor)

# dataloaderオブジェクト作成
train_loader = DataLoader(train, batch_size=128, shuffle=True)

# 学習処理
for epoch in range(2):  # epochに関するイテレーション
    for i, data in enumerate(train_loader, 0):  # batchに関するイテレーション
        inputs, labels = data                      # ミニバッチデータの取得
        optimizer.zero_grad()                      # パラメータの勾配を初期化
        outputs = model(inputs)                    # forward
        loss = criterion(outputs, labels)
        loss.backward()                            # backward
        optimizer.step()                           # パラメータの更新

と、datasetオブジェクト作成から数えると優に10行を超える記述が必要になってしまいます。

しかも予測時には更に

model.eval()
with torch.no_grad():
    # (予測処理)

って書かないといけなかったりとかして、とにかく記述が面倒。

もちろん、逆に一から書く設計になっているおかげで、学習・予測処理の合間に好きな処理を挟んだりすることができるという柔軟性があるという側面もあります。
ただネットワークの学習・評価をぱぱっとやりたい時などには、この面倒さがかなり高い心理的なハードルになってしまいます。

そこでskorchの出番です。
skorchを使えば、pytorchで構築したモデルの学習・予測の記述が非常に簡単になります。
pytorchをwrapしているだけなので、もちろんGPUも使えます。

インストール

skorchはpipからインストールできます。
(当然、pytorchは事前にインストールしておく必要があります。)

$ pip install skorch

バージョンは

$ python -c "import skorch; print(skorch.__version__)"   
0.5.0.post0

を使用しました。

ここから、skorchの代表的なクラスである、

  • NeuralNet
  • NeuralNetClassifier

について見ていきます。

NeuralNetの使い方

NeuralNetとは、pytorchで作成したNet(nn.Moduleを継承してネットワークを定義したクラスのことをここではこう呼びます)を引数に取り、scikit-learnライクなmodelオブジェクトを作成するskorchのクラスです。

NeuralNet — skorch 0.5.0.post0 documentation

試しにmnistデータの分類モデルの構築を行ってみましょう。

from keras.datasets import mnist

# mnistデータのロード
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# normalize・channelize, make dtype float32
x_train = np.expand_dims(x_train / 255, 1).astype(np.float32)
x_test = np.expand_dims(x_test / 255, 1).astype(np.float32)

# make dtype int64 for criterion
y_train = y_train.astype(np.int64)

modelオブジェクトの定義

まず、Moduleを継承したNetクラスを定義するところまでは同じ。

import torch
from torch import nn
import torch.nn.functional as F

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.conv2_drop = nn.Dropout2d()
        # 1600 = number channels * width * height
        self.fc1 = nn.Linear(1600, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))    # (32, 13, 13)
        x = F.relu(F.max_pool2d(self.conv2_drop(
            self.conv2(x)), 2))   # (64, 5, 5)
        # flatten over channel, height and width = 1600 = 64*5*5
        x = x.view(-1, x.size(1) * x.size(2) * x.size(3))
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        x = F.log_softmax(x, dim=-1)
        return x

続いて、skorchのNeuralNetを用いて、modelオブジェクトを以下のように作成します。
(ここで、epoch数やbatch_size・optimizerなども指定します)

from skorch import NeuralNet

device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = NeuralNet(
  CNN,            # ここで、先程定義したNetクラスを引数として与える
  max_epochs=10,
  optimizer=torch.optim.Adam,
  lr=0.001,
  device=device,
  batch_size=128,
  # デフォルトだと入力データの2割が検証に使われる。入力データすべてを学習に使うためには、以下の通りにする
  # train_split=None,
  criterion=nn.NLLLoss    # CNNの最後のactivationがlog_softmaxなので、lossはNLLoss。
)

以上で準備は終わりです。後はscikit-learnと全く同じです。
ここで定義したmodelオブジェクトのfitやpredictといったメッソドを実行すればいいだけです。

学習・予測・評価

学習も予測も1行で記述できてしまいます。

# 学習
model.fit(x_train, y_train)

# 予測
pred = model.prefict(x_test)   # CNNのforwardメソッドの出力(shape:(10000, 10))が得られる。

# 評価
pred = pred.argmax(axis=1)
print("acc:{}".format(accuracy_score(y_test, pred)))

ちなみに学習時には、学習の経過が以下のようにいい感じに出力されます。

スクリーンショット 2019-05-16 20.01.37.png

前回に比べて良くなったエポックに関しては色がつくよう仕様になっているようです。

その他の処理

scikit-learnのmodelオブジェクトに対して行える処理は一通り行うことができます。

例えばjoblibを用いたmodelのsave・loadなど。

from sklearn.externals import joblib
filename = 'model.obj'

# モデルの保存
joblib.dump(model, filename)

# 保存したモデルのロード
model = joblib.load(filename)

NeuralNetの注意点

このskorch.NeuralNetを用いて作成したmodelオブジェクトの扱いにはやや注意が必要です。
普通のscikit-learnで構築した分類モデル(LogisticRegressionなど)の場合は、

pred = model.predict(x_test)

とした時、predはaxis=1で勝手にargmaxを施されているので、pred.shape=(10000, )となります。

一方でskorchの場合は、

pred = model.predict(x_test)

とした場合、predは最後のLinear層->log_softmaxの活性化が施された値であり、pred.shape=(10000, 10)となります。

そのためskorchこのpredに対して直接

accuracy_score(y_test, pred)

の計算を行うことができません。

まあ、普通にaccuracyを求めたいだけならば事前にpredに対してargmax(axis=1)を取ればいいだけなのですが、これで何が困るかというと、例えばscikit-learnのGridSearchCVを用いてaccuracyが最も良くなるパラメータを見つけたいと思った時に、

grid_search = GridSearchCV(model, param_grid, scoring='accuracy')

とすることができないのです。
(scoring='accuracy'とすると、model.predictの出力とyとのaccuracyをとる処理が内部で走るような設計になっているため)

そこでNeuralNetClassifierの出番です。

NeuralNetClassifierの使い方

NeuralNetClassifierとNeuralNetの違い

この2つの違いはわずかで、NeuralNetClassifierを用いると、

  • model.predictにより、forwardメソッドの出力をargmaxした結果が得られる
    • そのためGridSearchCV(model, param_grid, scoring='accuracy')が可能となる
  • model.predict_probaメソッドが使える

などのメリットがあります。

しかし学習時の設定に色々な癖があるので、上述の2つのメリットによる嬉しさがそんなにない場合には、NeuralNetの方を使ったほうが良いでしょう。

一応、NeuralNetClassifierの使い方と注意点・ハマりポイントみたいなのを以下にまとめておきます。

NeuralNetClassifierの注意点

まず注意として、NeuralNetClassifierを用いる場合には、Netオブジェクトのforwardメソッドの最後の活性化関数は必ずF.softmax(dim=-1)にしないといけません。log_softmaxだと、predict_proba時に適切に確率の値を返してくれません。

なお、この制限は2クラス分類であっても同様です。普通は、NNの最後の出力のsizeを1にしてsigmoidを施したくなりますが、NeuralNetClassifierだとそれは機能してくれません。
NeuralNetClassifierを使う場合には、必ずどんなケースでもNNの最後の出力は

  • size:class数
  • actication:F.softmax

にしましょう。

import torch
from torch import nn
import torch.nn.functional as F

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.conv2_drop = nn.Dropout2d()
        # 1600 = number channels * width * height
        self.fc1 = nn.Linear(1600, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))    # (32, 13, 13)
        x = F.relu(F.max_pool2d(self.conv2_drop(
            self.conv2(x)), 2))   # (64, 5, 5)
        # flatten over channel, height and width = 1600 = 64*5*5
        x = x.view(-1, x.size(1) * x.size(2) * x.size(3))
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        # 最後の活性化は、必ずF.softmax(dim=-1)に。
        x = F.softmax(x, dim=-1)
        return x

続いて次もハマりポイントなのですが、modelオブジェクトを作成する際のcriterionはnn.NLLLossにします。
nn.CrossEntropyLossにはしません。

model = NeuralNet(
  CNN,
  max_epochs=10,
  optimizer=torch.optim.Adam,
  lr=0.001,
  device=device,
  batch_size=128,
  criterion=nn.NLLLoss    # forwardの最後がsoftmaxだが、ここはnn.NLLoss
)

公式ドキュメントにも、

Note that skorch does not automatically apply any nonlinearities to the outputs (except internally when determining the PyTorch NLLLoss, see below). That means that if you have a classification task, you should make sure that the final output nonlinearity is a softmax. Otherwise, when you call predict_proba(), you won’t get actual probabilities.

とあります(NeuralNet — skorch 0.5.0.post0 documentation)。

なんか気持ち悪いですが、この部分をしっかり守っておけば、ちゃんと学習・予測を実行してくれますし、GridSearchCVでscoringにaccuracyを指定することも可能です。

実際にGridSearchCVと併せてパラメータ探索を行うスクリプトは
skorch/Advanced_Usage.ipynb at master · skorch-dev/skorch
をご覧ください。

まとめ

pytorchの学習・予測処理を簡単に記述することができるパッケージskorchを紹介しました。
skorchを使うと、

  • Define by Runなどのpytorchに元々あったメリットを失うことなく、
  • まさにscikit-learnやkerasのようなお手軽さでモデルの学習・予測を行うことができる
  • また、pipelineやGridSearchCVなどのscikit-learnの便利機能も使える

と、まさに色々なパッケージのいいとこ取りをしたようなパッケージです。

ただ、やはり「各ミニバッチ学習終了後ごとにある処理を走らせたい」などの細かい処理をやろうと思ったら、pytorchそのものだけを使ったほうが良いので、柔軟に使い分けるのが良いかと思います。

skorch.NeuralNetを使ってmnistデータに対しCNNの学習・評価を行ったスクリプトを以下に置きました。

https://gist.github.com/AnchorBlues/b51836c96b90b25b35f209ce7ac8f522

参考URL

AnchorBlues
とある大学院で地球惑星科学を専攻した後、現在はとある民間企業で働いています。応用数学(数理最適化、データマイニング、人工知能、etc...)の研究員になる予定です。
http://anchorblues.hatenablog.com/
ntt-data-msi
数理科学とコンピュータサイエンスの融合!!
http://www.msi.co.jp/
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
No comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
ユーザーは見つかりませんでした