Edited at

速習 pytorch-lightning: 今すぐ機械学習の実験をしたいそこのキミへ


概要

pytorch lightningは、PyTorchでの開発スピードを爆速にしてくれるライブラリです。今回はこのライブラリを使ってサクッとCNNを実装していきます。


コード


model.py

import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as ptl

class CNN(ptl.LightningModule):
# モデルの定義(PyTorchと一緒)
def __init__(self):
super(CNN, self).__init__()
self.c1 = nn.Sequential(
nn.Conv2d(1, 256, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2)
)
self.c2 = nn.Sequential(
nn.Conv2d(256, 256, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2)
)
self.linear = nn.Linear(12544, 10)
# フィードフォワードの計算の定義(PyTorchと一緒)
def forward(self, image):

h = self.c1(image)
h = self.c2(h)

batch_size = h.size(0)
h = self.linear(h.view(batch_size, -1))
return F.log_softmax(h, dim=1)

# accuacyを計算するためのプライベート関数
def _accuracy(self, preds, labels):
_, preds = torch.max(preds, dim=1)
return (preds == labels).sum().float() / preds.size(0)

# ミニバッチに対するトレーニングの関数
# 'loss'をキーにしないとバックワードが入らない
def training_step(self, batch, batch_nb):
images, labels = batch
preds = self.forward(images)
return {'loss': F.nll_loss(preds, labels)}

# ミニバッチに対するバリデーションの関数
def validation_step(self, batch, batch_nb):
images, labels = batch
preds = self.forward(images)
return {'val_nll_loss': F.nll_loss(preds, labels),
'val_accuracy': self._accuracy(preds, labels)}

# バリデーションループが終わったときに実行される関数
def validation_end(self, outputs):
avg_val_accuracy = torch.stack([x['val_accuracy'] for x in outputs]).mean()
return {'avg_val_accuracy': avg_val_accuracy}

# 最適化アルゴリズムの指定
def configure_optimizers(self):
return [torch.optim.Adam(self.parameters(), lr=0.001)]

# データローダーの定義
@ptl.data_loader
def tng_dataloader(self):
return DataLoader(
MNIST(
os.getcwd(),
train=True,
download=True,
transform=transforms.ToTensor()),
batch_size=32)

@ptl.data_loader
def val_dataloader(self):
return DataLoader(
MNIST(
os.getcwd(),
train=True,
download=True,
transform=transforms.ToTensor()),
batch_size=32)

@ptl.data_loader
def test_dataloader(self):
return DataLoader(
MNIST(
os.getcwd(),
train=True,
download=True,
transform=transforms.ToTensor()),
batch_size=32)


modelの定義はこんな感じで色々詰め込んであります。割とボイラープレートが少なくてサラっとかけそうです。

実際に学習を進めるコードはこちら。


fit.py

from models.cnn import CNN

import os
from pytorch_lightning import Trainer
from test_tube import Experiment

model = CNN()
exp = Experiment(save_dir=os.getcwd())

trainer = Trainer(experiment=exp, max_nb_epochs=1, train_percent_check=0.1)
trainer.fit(model)


コード数もだいぶ少なくて便利そうですね。ちなみに、このコードを実行すればTensorBord用のログが保存されます。実行中の画面はこんな感じ。

スクリーンショット 2019-08-06 10.35.57.png

TensorBordの実行結果はこちら。

スクリーンショット 2019-08-06 10.39.53.png

こんな感じで、学習、ロギングもまるっとやってくれます。便利ですねー


所感

簡単なモデルをサクッと作るのは良さそうですし、ログとかもよしなに勝手に保存してくれるのは、簡単にTensorBoardで可視化できるので良いですね。

ただ、GANみたいに複数のロス関数に対してbackwardを走らせたいときとかはどうするんだろう、というのがまだわからないですね。その辺りもラッパーの恩恵を受けられると良いけど。。。って感じです。


まとめ

今回はpytorch-lightningを使ってCNNを学習してみました。興味ある人はぜひ使ってみてください!