0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

初めてのtorchlightningシリーズー”torch + lightning” とは何か

Last updated at Posted at 2023-02-25

はじめに

今時にAIシステムを開発するにはよく知られているフレームワークは torch でしょう。動的な計算図(dynamic computation graph)を基盤として融通のきくAIモデルも構築できるのは主なメリットです。コミュニティにたくさんの3-rdライブラリを支援する上に、大体に車輪の再発明を防ぎます。しかし、torchはほかのAIフレームワークより融通のきく一方で、いくつかのところ(訓練ルプーにモデルのデバイスとか、データローダの仕組みとか)にエンジニアリングのコードを書かなければならない。

このシリーズの主役はtorchを基に研究員向けに開発したAIフレームワークtorchlightningです。従来の訓練パイプラインにリファクタリングして、大幅にエンジニアリングのコードを減少します。その故に、訓練に関する仕組みと準備は楽になります(何かtorch+lightningという狙いがありそうですw)。誰でもprint('hello world')のようにAIシステムの開発を入門できます。
image.png

今回の解説はtorchlightningを主役として、以下のテーマを順番に説明します(yeah ~):
1.基本的な設定
2. torchlightningを味わてみる

基本的な設定

まずはpythonの環境を適当に設定する方法を伝えます。いくつかの方法はありますが、筆者に一番好みの仕方はcontainerで環境を作ります。

1.Build from Dockerfile

(1) 以下のDockerfileをビルドして:docker build -t joseph/ntch .

# NGC container, one of the best choice in the world ~
FROM nvcr.io/nvidia/pytorch:23.01-py3

# setup my cmd alias
COPY ./aliases /root/.bash_aliases

# install utils
RUN apt update && \
        apt install vim -y && \
        apt install less -y && \
        apt install tmux -y && \
        apt install build-essential -y

# setup root ssh pwd
COPY ./def_pwd /root/pwd

# ssh setting
RUN apt install openssh-server -y && \
        cat /root/pwd | chpasswd && \
        rm -f /root/pwd && \
        sed -i "s/#Port.*/Port 22/" /etc/ssh/sshd_config && \
        sed -i "s/#PermitRootLogin.*/PermitRootLogin yes/" /etc/ssh/sshd_config && \
        sed -i "s/#PasswordAuthentication.*/PasswordAuthentication yes/" /etc/ssh/sshd_config && \
        echo 'service ssh restart' >> ~/.bashrc

# expose the port 22(which is the default port of ssh)
EXPOSE 22

# set entrypoint to restart ssh automatically
ENTRYPOINT service ssh restart && /bin/bash

(2) そしてcontainerを起動して、
docker run -td --gpus '"device=0, 1"' -p 9478:22 --name dkr9478 --ipc=host -v ~/Desktop:/workspace -v /4TB-data:/storage joseph/ntch:latest

(3) sshでcontainerをログインして、
ssh root@localhost -p 9478

(4) そしてpipで必要なパッケージをインストールして、
pip install pytorch-lightning -q

NCG containerにwhich pippipを追跡すると、pipは実にpip3のsoftlinkですw
python2のpipを誤作動する恐れはありません~

2.Anaconda あるいは miniconda

普通にAnacondaのウェブサイトに最新版のインストーラしかいません。直接にarchiveのリストを開いてみましょう
:link: https://repo.anaconda.com/archive/

Anacondaは便利ですが、スペースは少ないために軽いバージョンminicondaをお勧めますよ~
:link: https://docs.conda.io/en/latest/miniconda.html

以下はminicondaを例として、scriptを執行してみてインストールします:
image.png
(1) wget https://repo.anaconda.com/miniconda/Miniconda3-py310_23.1.0-1-Linux-x86_64.sh
yk.JPG

(2) sh Miniconda3-py310_23.1.0-1-Linux-x86_64.sh
image.png
当然にyesと答える!!
yk.JPG

(3) conda create --name tchl python==3.9
conda activate tchl ; pip install pytorch-lightning -q

それでは、環境の準備は以上です。

torchlightningを味わてみる

まずはtoy-projectを作りましょう。
torchlightningは自動的に必要なtorchパケットを備えますが、torchに関するほかのパケットのバージョンを合わないかもしれません。やはり自分でstableバージョンの情報を検査する方が心強いです。
pre-requirement:

# cat req.txt
pip install pytorch-lightning==1.9.3
# torchに関するほかのパケット
pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchtext==0.14.1 --extra-index-url https://download.pytorch.org/whl/cu116
# timmはあまりそういう罠がない
pip install timm
pip install easy_configer

そして、toy-projectの組み立ては:
>> ls -al toy_proj
image.png

エントリーポイントはtrain.pyですので、コードをみましょう:

from data import TinyImagenetDataModule
from timm.models.vision_transformer import vit_small_patch16_224

from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import LearningRateMonitor
import pytorch_lightning as pl

import torch
import os
from easy_configer.Configer import Configer

class ViTModule(pl.LightningModule):
    def __init__(self, mixup: bool = False, schedule_lr_by_step: bool = False, num_classes: int = None) -> None:
        super().__init__()
        self.schedule_lr_by_step = schedule_lr_by_step
        self.model = vit_base_patch16_224(drop_rate=0.1, weight_init='jax', num_classes=num_classes)
        self.criterion =  MixupLoss( torch.nn.CrossEntropyLoss() ) if mixup \
                            else torch.nn.CrossEntropyLoss()

    def configure_optimizers(self):
        opt = torch.optim.Adam(self.model.parameters(), lr=0.0001)
        if self.schedule_lr_by_step:
            num_steps = self.trainer.estimated_stepping_batches
            interval = 'step'
        else:
            num_steps = self.trainer.max_epochs
            interval = 'epoch'
        num_warmups = int(num_steps * 0.1)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, num_steps, num_warmups)
        scheduler = {'scheduler': scheduler, 'interval': interval}
        return [opt], [scheduler]

    def training_step(self, batch, batch_idx):
        inputs, targets = batch
        logits = self.model(inputs)
        if self.mixup:
            loss = self.criterion(logits, **targets)
        else:
            loss = self.criterion(logits, targets)
        return loss

def get_config_str():
    return '''
    seed = 42@int

    [data_loader] 
    data_root = /data/tiny-imgnet        
    batch_size = 128@int

    [module]
    schedule_lr_by_step = False@bool
    mixup = False@bool
    
    [trainer]
    ckpt = @str
    
    max_epochs = 100@int
    devices = 1@int
    accelerator = gpu@str
    strategy = ddp@str
    accumulate_grad_batches = 4@int
    precision = 16@int
    '''


if __name__ == '__main__':
    # read config
    cfg = Configer()
    cfg_str = get_config_str()
    cfg.cfg_from_str(cfg_str)
       
    # fixed training results, prepare dataset, prepare model
    seed_everything(cfg.seed)
    dm = TinyImagenetDataModule( **cfg.data_loader )
    model = ViTModule(**args.module, num_classes=dm.num_classes)
       
    # combine all components and run the trainer!!
    trainer = pl.Trainer(
        resume_from_checkpoint=args.trainer['ckpt'],
        callbacks=[LearningRateMonitor('step')],
        gradient_clip_val=1.0,
        **args.trainer
    )
    trainer.fit(model, datamodule=dm)

このファイルはTinyImagenetデータセットにViTを訓練します。訓練に必要なコンポーネント:

  1. ViT model
  2. TinyImagenet Dataset

ViT model

よく知られた論文 Transformer is all you need で、self-attention モジュールもCVの分野に使われてViTっという方法を提出しました。ここにtimmライブラリーを使って ViT の実装を導入します (self.model = vit_base_patch16_224(..., num_classes=num_classes))。

自分でvit_base_patch16_224を実現するのもいいですが、timmはすでによく知らせたモデルを提供します。ちなみに、timmのモデルはtorch.nn.Moduleを継承して作るものです。

そして、pl.LightningModuleを継承していくつかのメソッドを書き直します。大雑把に言うと、モデルは訓練と推論って二つの状態があります(訓練の流れは推論のと違いところはあります)。training_stepメソッドは訓練の流れを定義して、訓練のうちにモデルを評価するためにvalidation_stepメソッドもあります。さらに、推論を行うときにモデル状態の更新を防ぐために、prediction_stepメソッドに推論の流れを定義します。

TinyImagenet Dataset

TinyImagenetDataModule( **cfg.data_loader )pl.LightningDataModule に継承していくつかのメソッドをオーバードライブするもので、pl. Trainer に相性が良い。そもそもpl.LightningDataModule もtorchのモジュールを継承して、いくつかのメソッドを書き直しただけです。気楽に必要なコンポーネントを用意できます。


参考資料とリンク

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?