LoginSignup
44
31

More than 3 years have passed since last update.

速習 pytorch-lightning: 今すぐ機械学習の実験をしたいそこのキミへ

Last updated at Posted at 2019-08-06

概要

pytorch lightningは、PyTorchでの開発スピードを爆速にしてくれるライブラリです。今回はこのライブラリを使ってサクッとCNNを実装していきます。

コード

model.py
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as ptl


class CNN(ptl.LightningModule):
    # モデルの定義(PyTorchと一緒)
    def __init__(self):
        super(CNN, self).__init__()
        self.c1 = nn.Sequential(
            nn.Conv2d(1, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.c2 = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.linear = nn.Linear(12544, 10)
    # フィードフォワードの計算の定義(PyTorchと一緒)
    def forward(self, image):

        h = self.c1(image)
        h = self.c2(h)

        batch_size = h.size(0)
        h = self.linear(h.view(batch_size, -1))
        return F.log_softmax(h, dim=1)

    # accuacyを計算するためのプライベート関数
    def _accuracy(self, preds, labels):
        _, preds = torch.max(preds, dim=1)
        return (preds == labels).sum().float() / preds.size(0)

    # ミニバッチに対するトレーニングの関数
    # 'loss'をキーにしないとバックワードが入らない 
    def training_step(self, batch, batch_nb):
        images, labels = batch
        preds = self.forward(images)
        return {'loss': F.nll_loss(preds, labels)}

    # ミニバッチに対するバリデーションの関数
    def validation_step(self, batch, batch_nb):
        images, labels = batch
        preds = self.forward(images)
        return {'val_nll_loss': F.nll_loss(preds, labels),
                'val_accuracy': self._accuracy(preds, labels)}

    # バリデーションループが終わったときに実行される関数
    def validation_end(self, outputs):
        avg_val_accuracy = torch.stack([x['val_accuracy'] for x in outputs]).mean()
        return {'avg_val_accuracy': avg_val_accuracy}

    # 最適化アルゴリズムの指定
    def configure_optimizers(self):
        return [torch.optim.Adam(self.parameters(), lr=0.001)]

    # データローダーの定義
    @ptl.data_loader  
    def tng_dataloader(self):
        return DataLoader(
            MNIST(
                os.getcwd(),
                train=True,
                download=True,
                transform=transforms.ToTensor()),
            batch_size=32)

    @ptl.data_loader
    def val_dataloader(self):
        return DataLoader(
            MNIST(
                os.getcwd(),
                train=True,
                download=True,
                transform=transforms.ToTensor()),
            batch_size=32)

    @ptl.data_loader
    def test_dataloader(self):
        return DataLoader(
            MNIST(
                os.getcwd(),
                train=True,
                download=True,
                transform=transforms.ToTensor()),
            batch_size=32)

modelの定義はこんな感じで色々詰め込んであります。割とボイラープレートが少なくてサラっとかけそうです。
実際に学習を進めるコードはこちら。

fit.py
from models.cnn import CNN
import os
from pytorch_lightning import Trainer
from test_tube import Experiment

model = CNN()
exp = Experiment(save_dir=os.getcwd())

trainer = Trainer(experiment=exp, max_nb_epochs=1, train_percent_check=0.1)
trainer.fit(model)

コード数もだいぶ少なくて便利そうですね。ちなみに、このコードを実行すればTensorBord用のログが保存されます。実行中の画面はこんな感じ。

スクリーンショット 2019-08-06 10.35.57.png

TensorBordの実行結果はこちら。
スクリーンショット 2019-08-06 10.39.53.png
こんな感じで、学習、ロギングもまるっとやってくれます。便利ですねー

所感

簡単なモデルをサクッと作るのは良さそうですし、ログとかもよしなに勝手に保存してくれるのは、簡単にTensorBoardで可視化できるので良いですね。
ただ、GANみたいに複数のロス関数に対してbackwardを走らせたいときとかはどうするんだろう、というのがまだわからないですね。その辺りもラッパーの恩恵を受けられると良いけど。。。って感じです。

まとめ

今回はpytorch-lightningを使ってCNNを学習してみました。興味ある人はぜひ使ってみてください!

44
31
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
44
31