はじめに
今時にAIシステムを開発するにはよく知られているフレームワークは torch でしょう。動的な計算図(dynamic computation graph)を基盤として融通のきくAIモデルも構築できるのは主なメリットです。コミュニティにたくさんの3-rdライブラリを支援する上に、大体に車輪の再発明を防ぎます。しかし、torchはほかのAIフレームワークより融通のきく一方で、いくつかのところ(訓練ルプーにモデルのデバイスとか、データローダの仕組みとか)にエンジニアリングのコードを書かなければならない。
このシリーズの主役はtorchを基に研究員向けに開発したAIフレームワークtorchlightningです。従来の訓練パイプラインにリファクタリングして、大幅にエンジニアリングのコードを減少します。その故に、訓練に関する仕組みと準備は楽になります(何かtorch+lightningという狙いがありそうですw)。誰でもprint('hello world')
のようにAIシステムの開発を入門できます。
今回の解説はtorchlightningを主役として、以下のテーマを順番に説明します(yeah ~):
1.基本的な設定
2. torchlightningを味わてみる
基本的な設定
まずはpythonの環境を適当に設定する方法を伝えます。いくつかの方法はありますが、筆者に一番好みの仕方はcontainerで環境を作ります。
1.Build from Dockerfile
(1) 以下のDockerfileをビルドして:docker build -t joseph/ntch .
# NGC container, one of the best choice in the world ~
FROM nvcr.io/nvidia/pytorch:23.01-py3
# setup my cmd alias
COPY ./aliases /root/.bash_aliases
# install utils
RUN apt update && \
apt install vim -y && \
apt install less -y && \
apt install tmux -y && \
apt install build-essential -y
# setup root ssh pwd
COPY ./def_pwd /root/pwd
# ssh setting
RUN apt install openssh-server -y && \
cat /root/pwd | chpasswd && \
rm -f /root/pwd && \
sed -i "s/#Port.*/Port 22/" /etc/ssh/sshd_config && \
sed -i "s/#PermitRootLogin.*/PermitRootLogin yes/" /etc/ssh/sshd_config && \
sed -i "s/#PasswordAuthentication.*/PasswordAuthentication yes/" /etc/ssh/sshd_config && \
echo 'service ssh restart' >> ~/.bashrc
# expose the port 22(which is the default port of ssh)
EXPOSE 22
# set entrypoint to restart ssh automatically
ENTRYPOINT service ssh restart && /bin/bash
(2) そしてcontainerを起動して、
docker run -td --gpus '"device=0, 1"' -p 9478:22 --name dkr9478 --ipc=host -v ~/Desktop:/workspace -v /4TB-data:/storage joseph/ntch:latest
(3) sshでcontainerをログインして、
ssh root@localhost -p 9478
(4) そしてpipで必要なパッケージをインストールして、
pip install pytorch-lightning -q
NCG containerに
which pip
でpip
を追跡すると、pip
は実にpip3
のsoftlinkですw
python2のpip
を誤作動する恐れはありません~
2.Anaconda あるいは miniconda
普通にAnacondaのウェブサイトに最新版のインストーラしかいません。直接にarchiveのリストを開いてみましょう
https://repo.anaconda.com/archive/
Anacondaは便利ですが、スペースは少ないために軽いバージョンminicondaをお勧めますよ~
https://docs.conda.io/en/latest/miniconda.html
以下はminicondaを例として、scriptを執行してみてインストールします:
(1) wget https://repo.anaconda.com/miniconda/Miniconda3-py310_23.1.0-1-Linux-x86_64.sh
(2) sh Miniconda3-py310_23.1.0-1-Linux-x86_64.sh
当然にyesと答える!!
(3) conda create --name tchl python==3.9
conda activate tchl ; pip install pytorch-lightning -q
それでは、環境の準備は以上です。
torchlightningを味わてみる
まずはtoy-projectを作りましょう。
torchlightningは自動的に必要なtorchパケットを備えますが、torchに関するほかのパケットのバージョンを合わないかもしれません。やはり自分でstableバージョンの情報を検査する方が心強いです。
pre-requirement:
# cat req.txt
pip install pytorch-lightning==1.9.3
# torchに関するほかのパケット
pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchtext==0.14.1 --extra-index-url https://download.pytorch.org/whl/cu116
# timmはあまりそういう罠がない
pip install timm
pip install easy_configer
そして、toy-projectの組み立ては:
>> ls -al toy_proj
エントリーポイントはtrain.pyですので、コードをみましょう:
from data import TinyImagenetDataModule
from timm.models.vision_transformer import vit_small_patch16_224
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import LearningRateMonitor
import pytorch_lightning as pl
import torch
import os
from easy_configer.Configer import Configer
class ViTModule(pl.LightningModule):
def __init__(self, mixup: bool = False, schedule_lr_by_step: bool = False, num_classes: int = None) -> None:
super().__init__()
self.schedule_lr_by_step = schedule_lr_by_step
self.model = vit_base_patch16_224(drop_rate=0.1, weight_init='jax', num_classes=num_classes)
self.criterion = MixupLoss( torch.nn.CrossEntropyLoss() ) if mixup \
else torch.nn.CrossEntropyLoss()
def configure_optimizers(self):
opt = torch.optim.Adam(self.model.parameters(), lr=0.0001)
if self.schedule_lr_by_step:
num_steps = self.trainer.estimated_stepping_batches
interval = 'step'
else:
num_steps = self.trainer.max_epochs
interval = 'epoch'
num_warmups = int(num_steps * 0.1)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, num_steps, num_warmups)
scheduler = {'scheduler': scheduler, 'interval': interval}
return [opt], [scheduler]
def training_step(self, batch, batch_idx):
inputs, targets = batch
logits = self.model(inputs)
if self.mixup:
loss = self.criterion(logits, **targets)
else:
loss = self.criterion(logits, targets)
return loss
def get_config_str():
return '''
seed = 42@int
[data_loader]
data_root = /data/tiny-imgnet
batch_size = 128@int
[module]
schedule_lr_by_step = False@bool
mixup = False@bool
[trainer]
ckpt = @str
max_epochs = 100@int
devices = 1@int
accelerator = gpu@str
strategy = ddp@str
accumulate_grad_batches = 4@int
precision = 16@int
'''
if __name__ == '__main__':
# read config
cfg = Configer()
cfg_str = get_config_str()
cfg.cfg_from_str(cfg_str)
# fixed training results, prepare dataset, prepare model
seed_everything(cfg.seed)
dm = TinyImagenetDataModule( **cfg.data_loader )
model = ViTModule(**args.module, num_classes=dm.num_classes)
# combine all components and run the trainer!!
trainer = pl.Trainer(
resume_from_checkpoint=args.trainer['ckpt'],
callbacks=[LearningRateMonitor('step')],
gradient_clip_val=1.0,
**args.trainer
)
trainer.fit(model, datamodule=dm)
このファイルはTinyImagenetデータセットにViTを訓練します。訓練に必要なコンポーネント:
- ViT model
- TinyImagenet Dataset
ViT model
よく知られた論文 Transformer is all you need で、self-attention モジュールもCVの分野に使われてViTっという方法を提出しました。ここにtimmライブラリーを使って ViT の実装を導入します (self.model = vit_base_patch16_224(..., num_classes=num_classes)
)。
自分で
vit_base_patch16_224
を実現するのもいいですが、timmはすでによく知らせたモデルを提供します。ちなみに、timmのモデルはtorch.nn.Module
を継承して作るものです。
そして、pl.LightningModule
を継承していくつかのメソッドを書き直します。大雑把に言うと、モデルは訓練と推論って二つの状態があります(訓練の流れは推論のと違いところはあります)。training_step
メソッドは訓練の流れを定義して、訓練のうちにモデルを評価するためにvalidation_step
メソッドもあります。さらに、推論を行うときにモデル状態の更新を防ぐために、prediction_step
メソッドに推論の流れを定義します。
TinyImagenet Dataset
TinyImagenetDataModule( **cfg.data_loader )
はpl.LightningDataModule
に継承していくつかのメソッドをオーバードライブするもので、pl. Trainer
に相性が良い。そもそもpl.LightningDataModule
もtorchのモジュールを継承して、いくつかのメソッドを書き直しただけです。気楽に必要なコンポーネントを用意できます。
参考資料とリンク