3
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

pytorch-lightningの使い方

Posted at

はじめに

この記事ではpytorch-lightningの使い方を、いい感じのプログラムで紹介します

まず結論:ソースコードの書き方

pytorch_lightning_sample.py
import os
import pickle

import numpy as np
from PIL import Image                              # 画像を取り扱うために使用
import matplotlib.pyplot as plt                    # 画像のサンプル表示のために使用

import torch                                       # pytorch本体
import torch.nn as nn                              # ニューラルネットを構成する際の基本的なモジュールが入っている
from torchvision import transforms as transforms   # 画像前処理のために使用

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning import loggers as pl_loggers

# ==== データローダの作成など ============================================================================================

# データセットに対して、idxで指定された際に読み込み方法を指定するためのラッパークラス
class Dataset(torch.utils.data.Dataset):
    def __init__(self, data, labels):
        # 画像を変換して整形して保持
        self.data = np.array(data)                          # numpy形式に変換
        self.data = self.data.reshape(len(data), 3, 32, 32) # dataを整形
        self.data = self.data.transpose(0, 2, 3, 1)         # data[ミニバッチのindex][チャンネル][画像縦位置][画像横位置]と指定できるように順序交換

        # ラベルを保持
        self.labels = labels

        #  画像を前処理するための関数たちを登録
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(mean=(0.5,0.5,0.5), std=(0.5,0.5,0.5)),
            ]
        )

    # 指定されたindexのデータを辞書形式で返却するように設定する
    def __getitem__(self, index):
        img, label = self.data[index], self.labels[index] # 指定のデータを取得
        img = Image.fromarray(img)                        # 画像に変換
        img = self.transform(img)                         # transformをかける(tensor型に変換してから、正規化)
        return {'inputs':img, 'targets':label}            # 辞書形式で返却(辞書のkeyはニューラルネットのforwardの引数と同じ名前にする)

    # データセットの個数を返すように設定する
    def __len__(self):
        return len(self.data)

    # サンプル表示用の関数。一般には作らなくてOK。
    def plot(self, index):
        classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
        img, label = self.data[index], self.labels[index]
        plt.imshow(img)                         # 指定されたindexの画像を描画
        plt.title(f'label={classes[label]}')    # titleを"label=クラス名"という形式で設定
        plt.show()                              # 表示


class DataModule(pl.LightningDataModule):
    def __init__(self, dataset_path, split_ratio=(0.7, 0.1, 0.2), batch_size=32, thread_num=4):
        super().__init__()
        self.dataset_path = dataset_path
        self.split_ratio = split_ratio
        self.batch_size = batch_size
        self.thread_num = thread_num

    def setup(self, stage=None):
        # 成形されたデータセットの読み込み
        print(f'loading {self.dataset_path}')
        with open(self.dataset_path, 'rb') as f:
            dataset_raw = pickle.load(f, encoding='bytes')
        print('loading completed')

        # データセット形式に変換
        ds = Dataset(dataset_raw[b'data'], dataset_raw[b'labels'])
        ds.plot(index=np.random.randint(0, len(ds)))    # サンプルとして、ランダムに1件選んで描画

        # データセットを指定された比率に合わせて分割
        total_size = len(ds)
        train_size = int(total_size * self.split_ratio[0])  # 学習で使用するデータ個数
        valid_size = int(total_size * self.split_ratio[1])  # 検証で使用するデータ個数
        test_size  = int(total_size * self.split_ratio[2])  # テストで使用するデータ個数
        self.train_dataset = torch.utils.data.dataset.Subset(ds, range(0, train_size))                     # 指定された部分のデータのみを取り出す
        self.valid_dataset = torch.utils.data.dataset.Subset(ds, range(train_size, train_size+valid_size)) # 指定された部分のデータのみを取り出す
        self.test_dataset = torch.utils.data.dataset.Subset(ds, range(train_size+valid_size, total_size))  # 指定された部分のデータのみを取り出す
        print(f'dataset size: total {total_size}, train {train_size}, validation {valid_size}, test {test_size}')

    def train_dataloader(self):
        if len(self.train_dataset) != 0:
            return torch.utils.data.DataLoader(
                self.train_dataset,             # データセット
                batch_size=self.batch_size,     # イテレート時のバッチサイズ
                shuffle=True,                   # イテレート前にデータをシャッフルするか
                num_workers=self.thread_num,    # イテレート時に使用するスレッド数
                pin_memory=True,                # メモリを固定して高速化をするか
                drop_last=True,                 # 最後の端数部を落とすか
            )
        else:
            raise Exception('length of dataset is zero.')

    def val_dataloader(self):
        if len(self.valid_dataset) != 0:
            return torch.utils.data.DataLoader(
                self.valid_dataset,             # データセット
                batch_size=self.batch_size,     # イテレート時のバッチサイズ
                shuffle=False,                  # イテレート前にデータをシャッフルするか
                num_workers=self.thread_num,    # イテレート時に使用するスレッド数
                pin_memory=True,                # メモリを固定して高速化をするか
                drop_last=False,                # 最後の端数部を落とすか
            )
        else:
            raise Exception('length of dataset is zero.')

    def test_dataloader(self):
        if len(self.test_dataset) != 0:
            return torch.utils.data.DataLoader(
                self.test_dataset,              # データセット
                batch_size=self.batch_size,     # イテレート時のバッチサイズ
                shuffle=False,                  # イテレート前にデータをシャッフルするか
                num_workers=self.thread_num,    # イテレート時に使用するスレッド数
                pin_memory=True,                # メモリを固定して高速化をするか
                drop_last=False,                # 最後の端数部を落とすか
            )
        else:
            raise Exception('length of dataset is zero.')


# ====================================================================================================================

class LeNet(nn.Module):
    def __init__(self):
        # ここにはニューラルネットの構成で使用するモジュールを一通り書き出す。
        # Conv2dは畳み込み層、AvgPool2dは平均プーリング層、Flattenは数値を一列に並べてベクトル化する層、Linearは線形層(全結合層)を表している。使い方はググろう。

        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, kernel_size=5, padding=0, stride=1)     # 32*32 3チャンネル入力 → 28*28 6チャンネル出力
        self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2)                   # 28*28 6チャンネル入力 → 14*14 6チャンネル出力
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5, padding=0, stride=1)    # 14*14 6チャンネル入力 → 10*10 16チャンネル出力
        self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)                   # 10*10 16チャンネル入力 → 5*5 16チャンネル出力
        self.flatten = nn.Flatten()                                          # 5*5 25チャンネル入力 → 400(=5*5*25)出力
        self.fc1 = nn.Linear(400, 120)      # 400入力, 120出力
        self.fc2 = nn.Linear(120, 84)       # 120入力, 84出力
        self.fc3 = nn.Linear(84, 10)        # 84入力, 10出力
        self.softmax = nn.Softmax(dim=1)    # ソフトマックス関数
        self.loss = nn.CrossEntropyLoss()   # クロスエントロピー損失

    def forward(self, inputs, targets):
        # ここには順伝搬のやり方を書く。

        h = inputs

        # 作用
        h = torch.sigmoid(self.conv1(h))
        h = self.pool1(h)
        h = torch.sigmoid(self.conv2(h))
        h = self.pool2(h)
        h = self.flatten(h)
        h = torch.sigmoid(self.fc1(h))
        h = torch.sigmoid(self.fc2(h))
        h = self.softmax(self.fc3(h))

        # 損失を計算
        loss = self.loss(h, targets)

        # データ形式をdictに包んで出力
        return {
            "loss":loss,          # 損失の値を記録
            "hidden_states":h,    # 最終層の値を記録
        }


# ====================================================================================================================

class LitModule(pl.LightningModule):
    # ネットワークモジュールなどの定義
    def __init__(self, learning_rate):
        super().__init__()
        self.model = LeNet()
        self.learning_rate = learning_rate
        self.save_hyperparameters()

    # 順伝搬の処理
    def forward(self, **x):
        return self.model(**x)

    # オプティマイザの定義
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1, end_factor=0.1, total_iters=self.trainer.max_epochs)
        return [optimizer], [scheduler]

    # ==================================================================

    # 学習のバッチ実行処理
    def training_step(self, batch, batch_index):
        outputs = self.model(**batch)
        loss = outputs['loss']
        return {'loss': loss, 'correct': (torch.argmax(outputs["hidden_states"], dim=-1)==batch['targets']).to(torch.float32)}

    # 学習の全バッチ終了時の処理
    def training_epoch_end(self, outputs):
        train_loss = torch.hstack([dict_['loss'] for dict_ in outputs]).mean()
        train_accuracy = torch.hstack([dict_['correct'] for dict_ in outputs]).mean()
        self.log_dict({"train_loss": train_loss, "train_accuracy": train_accuracy})
        self.print({"train_loss": train_loss, "train_accuracy": train_accuracy})

    # 検証のバッチ実行処理
    def validation_step(self, batch, batch_index):
        with torch.no_grad():
            outputs = self.model(**batch)
        loss = outputs['loss']
        return {'loss': loss, 'correct': (torch.argmax(outputs["hidden_states"], dim=-1)==batch['targets']).to(torch.float32)}

    # 検証の全バッチ終了時の処理
    def validation_epoch_end(self, outputs):
        val_loss = torch.hstack([dict_['loss'] for dict_ in outputs]).mean()
        val_accuracy = torch.hstack([dict_['correct'] for dict_ in outputs]).mean()
        self.log_dict({"val_loss": val_loss, "val_accuracy": val_accuracy})
        self.print({"val_loss": val_loss, "val_accuracy": val_accuracy})

    # テストのバッチ実行処理
    def test_step(self, batch, batch_index):
        with torch.no_grad():
            outputs = self.model(**batch)
        loss = outputs['loss']
        return {'loss': loss, 'correct': (torch.argmax(outputs["hidden_states"], dim=-1)==batch['targets']).to(torch.float32)}

    # テストの全バッチ終了時の処理
    def test_epoch_end(self, outputs):
        test_loss = torch.hstack([dict_['loss'] for dict_ in outputs]).mean()
        test_accuracy = torch.hstack([dict_['correct'] for dict_ in outputs]).mean()
        self.print({"test_loss": test_loss, "test_accuracy": test_accuracy})

    # ==================================================================


# ====================================================================================================================


def main():
    save_dir = './result'

    data_module = DataModule(
        dataset_path='./cifar-10-batches-py/data_batch_1',
        split_ratio=(0.7, 0.1, 0.2),
        batch_size=32,
    )
    model = LitModule(learning_rate=0.0001)

    callbacks = [
        ModelCheckpoint(
            dirpath=save_dir,
            filename='epoch{epoch:02d}-val_loss{val_loss:.2f}', # チェックポイントのファイル名の形式
            monitor='val_loss',                                 # 基準とする量
            mode="min",                                         # 最小となるところを探す
        ),  # modelのチェックポイントを作成
        # EarlyStopping(
        #     monitor="val_loss",                                 # 基準とする量
        #     mode="min",                                         # 最小となるところを探す
        # ),  # early-stoppingを利用
    ]
    trainer = pl.Trainer(
        max_epochs=10,
        logger=[pl_loggers.TensorBoardLogger(save_dir=save_dir)],
        callbacks=callbacks,
        accelerator='gpu',
        devices=[0],                            # 使用するGPUのIDのリスト
        # auto_lr_find=True,                      # learning rateを自動で設定するか
        # accumulate_grad_batches=1,              # 勾配を累積して一度に更新することでバッチサイズを仮想的にN倍にする際のN
        # gradient_clip_val=1,                    # 勾配クリッピングの値
        # fast_dev_run=True,                      # デバッグ時にonにすると、1回だけtrain,validを実行する
        # overfit_batches=1.0,                    # デバッグ時にonにすると、train = validで学習が進み、過学習できているかを確認できる
        # deterministic=True,                     # 再現性のために乱数シードを固定するか
        # resume_from_checkpoint='bbb/aaa.ckpt',  # チェックポイントから再開する場合に利用
        # precision=16,                           # 小数を何ビットで表現するか
        # amp_backend="apex",                     # 少数の混合方式を使用するかどうか。nvidiaのapexがインストールされている必要あり。
        # benchmark=True,                         # cudnn.benchmarkを使用して高速化するか(determisticがTrueの場合はFalseに上書きされる)
    )
    # trainer.tune(model, datamodule=data_module)   # 「auto_lr_find=True」を指定した場合に実行する
    trainer.fit(model, datamodule=data_module)
    trainer.test(model, datamodule=data_module)

if __name__ == '__main__':
    main()

終わりに

  • とりあえずメモ書き程度で残しておきます
  • 今後もう少し説明を書くかも
3
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?