6
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

PyTorch Lightningのはじめの一歩

Last updated at Posted at 2022-12-04

はじめに

これは、PyTorchは少し知っているけど、PyTorch Lightningは全く知らない人が、PyTorch Lightningを勉強してみましたという話です。
公式のサンプルを元にPyTorchの初歩的な部分を解説しています。

この記事はHEROZ社内での勉強会の内容をリライトしたものです。

Pytorch Lightningについて

Pytorch Lightningについて簡単に概要を触れておくと、Pytorch LightningはPytorchのラッパーで、
学習ループなどの定型文(boilerplate)をラッピングし学習周りのコードを簡潔にわかりやすく書けるようにするライブラリです。

類似のフレームワークとしては、以下のものがあります。

  • Ignite
  • Catalyst

最近は、PyTorch Lightningがだいぶ優勢で、利用例や解説記事も多く見つかります。

Pytorch Lightningの基本

以下の3つの部分がPytorch Lightningの基本です。

  • データに関する部分(LightningDataModule)
  • 学習に関する部分(LightningModule)
  • 実際の学習を行う部分(Trainer)

LightningDataModule

データ周りの処理を集約したクラスです。
以下の公式のサンプルのように、train/validation/testのそれぞれで使用するPyTorchのDataLoaderを返すメソッドを作成します。

class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str = "path/to/dir", batch_size: int = 32):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

    def setup(self, stage: str):
        self.mnist_test = MNIST(self.data_dir, train=False)
        self.mnist_predict = MNIST(self.data_dir, train=False)
        mnist_full = MNIST(self.data_dir, train=True)
        self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size)

    def predict_dataloader(self):
        return DataLoader(self.mnist_predict, batch_size=self.batch_size)

    def teardown(self, stage: str):
        # Used to clean-up when the run is finished
        ...

公式の例だとMNISTを使っているのでシンプルに見えますが、実際には普通にPyTorchのDatasetを書く必要があるので、普通にPyTorchで書くのと大差はありません。

LigntningModule

学習で行うことを集約したクラスです。
以下の公式のサンプルのように、モデルの定義、forwardの定義、train/val/testの各stepで実行すること、Optimizerの定義を記載します。

import pytorch_lightning as pl
import torch.nn as nn
import torch.nn.functional as F


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)

こちらも、普通にPyTorchで書くのと大差はなく、コードが集約されることが最大の利点です。

Trainer

最後に、作成したLightningDataModuleとLightningModuleをTrainerに渡し、fitを呼べば学習が開始します。
https://pytorch-lightning.readthedocs.io/en/latest/starter/introduction.html#train-the-model

PyTorchで学習するコードを書くには、このあたりいろいろと書かなくてはいけないですが、Trainer作ってfitを呼べばとりあえず学習は始まります。

また、Trainerには、以下のよいところがあります。

  • ModelCheckPointやEarlyStoppingなどの基本的な機能は実装されており、callbacksに指定するのみ
  • TensorBoardやMLFlow等の様々な実験管理ライブラリに対応しており、loggerに指定するのみ
  • デバイス依存のコード(CPU/GPU/TPU)も楽にかける

このあたりはいろいろと用意されていて、モデルの開発に注力できるというところがPyTorch Lightningの利点なのかなというように思います。

コンピュータ将棋界での使用

コンピュータ将棋に関連したところでもちらほらPyTorch Lightningを使用しているケースがありますので、ちょっと紹介します。

チェスソフトのStockfishの学習ルーチンであるnnue-pytorchにPytorch Lightningが使用されています。
https://github.com/glinscott/nnue-pytorch

こちらは、nodchipさんの手によって将棋の学習もできるようになっています。
https://github.com/nodchip/nnue-pytorch

また、プロ将棋棋士でもある谷合さんがヨビノリたくみさんとのYoutube対戦のために作成された、Bertの将棋ソフトでも学習にPytorch Lightningが使用されています。
https://github.com/nyoki-mtl/bert-mcts-youtube

まとめ

PyTorch Lightningを使用する利点としては、

  • 利用例がたくさんあり、調べればたくさんの情報がある
  • PyTorchだけで書くより楽にコードが書ける
  • コードの書き方がある程度統一される

また、欠点は、

  • PyTorch Lightningの作法を覚えないといけない
  • PyTorch Lightningで用意されていないものを使おうとすると結構大変

このあたりは、フレームワークを使う上では避けられない点で、この点が気にならなければ良いフレームワークかなという気がしました。

参考文献

最後に、この記事を書くにあたり参考にしたWeb記事のリンクを張っておきます。どれも、Pytorch Lightningを知るのにとても勉強になりました。

おまけ

HEROZでは、2022年7月に将棋解析サービス、棋神アナリティクスをリリースしました。
https://kishin-analytics.heroz.jp/
最新・最強の将棋ソフトでお手軽に将棋の解析ができるサービスです。

さらに、2022年12月に月額1100円から使用できるプランもリリースしました。
アカウント作成から30日以内の間、無料で1時間までお試しで解析できるので、ぜひ一度試していただければうれしいです。

6
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?