Help us understand the problem. What is going on with this article?

速習 pytorch-lightning: 今すぐ機械学習の実験をしたいそこのキミへ

More than 1 year has passed since last update.

概要

pytorch lightningは、PyTorchでの開発スピードを爆速にしてくれるライブラリです。今回はこのライブラリを使ってサクッとCNNを実装していきます。

コード

model.py
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as ptl


class CNN(ptl.LightningModule):
    # モデルの定義(PyTorchと一緒)
    def __init__(self):
        super(CNN, self).__init__()
        self.c1 = nn.Sequential(
            nn.Conv2d(1, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.c2 = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.linear = nn.Linear(12544, 10)
    # フィードフォワードの計算の定義(PyTorchと一緒)
    def forward(self, image):

        h = self.c1(image)
        h = self.c2(h)

        batch_size = h.size(0)
        h = self.linear(h.view(batch_size, -1))
        return F.log_softmax(h, dim=1)

    # accuacyを計算するためのプライベート関数
    def _accuracy(self, preds, labels):
        _, preds = torch.max(preds, dim=1)
        return (preds == labels).sum().float() / preds.size(0)

    # ミニバッチに対するトレーニングの関数
    # 'loss'をキーにしないとバックワードが入らない 
    def training_step(self, batch, batch_nb):
        images, labels = batch
        preds = self.forward(images)
        return {'loss': F.nll_loss(preds, labels)}

    # ミニバッチに対するバリデーションの関数
    def validation_step(self, batch, batch_nb):
        images, labels = batch
        preds = self.forward(images)
        return {'val_nll_loss': F.nll_loss(preds, labels),
                'val_accuracy': self._accuracy(preds, labels)}

    # バリデーションループが終わったときに実行される関数
    def validation_end(self, outputs):
        avg_val_accuracy = torch.stack([x['val_accuracy'] for x in outputs]).mean()
        return {'avg_val_accuracy': avg_val_accuracy}

    # 最適化アルゴリズムの指定
    def configure_optimizers(self):
        return [torch.optim.Adam(self.parameters(), lr=0.001)]

    # データローダーの定義
    @ptl.data_loader  
    def tng_dataloader(self):
        return DataLoader(
            MNIST(
                os.getcwd(),
                train=True,
                download=True,
                transform=transforms.ToTensor()),
            batch_size=32)

    @ptl.data_loader
    def val_dataloader(self):
        return DataLoader(
            MNIST(
                os.getcwd(),
                train=True,
                download=True,
                transform=transforms.ToTensor()),
            batch_size=32)

    @ptl.data_loader
    def test_dataloader(self):
        return DataLoader(
            MNIST(
                os.getcwd(),
                train=True,
                download=True,
                transform=transforms.ToTensor()),
            batch_size=32)

modelの定義はこんな感じで色々詰め込んであります。割とボイラープレートが少なくてサラっとかけそうです。
実際に学習を進めるコードはこちら。

fit.py
from models.cnn import CNN
import os
from pytorch_lightning import Trainer
from test_tube import Experiment

model = CNN()
exp = Experiment(save_dir=os.getcwd())

trainer = Trainer(experiment=exp, max_nb_epochs=1, train_percent_check=0.1)
trainer.fit(model)

コード数もだいぶ少なくて便利そうですね。ちなみに、このコードを実行すればTensorBord用のログが保存されます。実行中の画面はこんな感じ。

スクリーンショット 2019-08-06 10.35.57.png

TensorBordの実行結果はこちら。
スクリーンショット 2019-08-06 10.39.53.png
こんな感じで、学習、ロギングもまるっとやってくれます。便利ですねー

所感

簡単なモデルをサクッと作るのは良さそうですし、ログとかもよしなに勝手に保存してくれるのは、簡単にTensorBoardで可視化できるので良いですね。
ただ、GANみたいに複数のロス関数に対してbackwardを走らせたいときとかはどうするんだろう、というのがまだわからないですね。その辺りもラッパーの恩恵を受けられると良いけど。。。って感じです。

まとめ

今回はpytorch-lightningを使ってCNNを学習してみました。興味ある人はぜひ使ってみてください!

yamad07
機械学習の研究をしています。
http://yamad07.net
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away