1.はじめに
この記事は主にpytorch_lightningのcallbackや学習がどの様に呼び出されているのかに注力を当ててコードを追って見ました.また,wandbloggerというものがどの様に実装されているのかを確認しました.しかし,この記事はメモ書きのようにとっ散らかっているため,参考程度にしていただけると幸いです.
1.0 結論
- wandbloggerもwandbも同じ挙動を再現できる.
- Loopというクラスで定義されている部分が肝でpytorch_lightningの学習は進んでいく.
- 完全に理解するには,なかなかハードであると感じた.
1.1. wandbについて
- wandbというMLの実験管理ツールをご存知でしょうか.
- MLを用いた実験をクラウド上で管理できる様になります.
- モデルの管理やロードも比較的簡単に行えます.
1.2. pytorch-lightningの実装について
- pytorch-lightningの実装に関しては,直接wandbを用いるのではなく,WandbLoggerを使います(下がpseudo[/ˈsjuːdəʊ/]-codeです).
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import Callback
import pytorch_lightning as pl
class MyModule(pl.LightningModule): # (A)
...
class MyCallback(Callback): # (B1)
def on_validation_batch_end(
self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
...
wandb_logger=WandbLogger(**kwargs) # (C)
trainer=pl.Trainer(logger=wandb_logger,callbacks=Mycallback(),**kwargs)
model=MyModule() # (B1)
wandb_logger.watch(model, log="all") #(C)
trainer.fit(model, train_loader, val_loader) # (B2)
model = MyModule.load_from_checkpoint(
trainer.checkpoint_callback.best_model_path
) # (B)
-
wandbを直接使用できない理由は,
- trainerの中でmodelを学習をするためで,wandbを直接こちらに組み込むには手間がかかるからです.
-
また,pytorch-lightning(pl)を用いることで,
- モデルの学習過程も一括で管理できる利点があるため,wandbを使いつつplを使いたいと考えています.
また,plを使うためこちらに関しても一部まとめたいと思います.
2. pytorch-lightningの全体について
- まずは,wandbloggerに関係ないplの基本的な部分をおさらいします.
2.1. pl.LightningModule (A)
こちらのクラスを継承することで学習や検証に関することを一括で定義できます.
- forward, configure_optimizers,training_step,validation_step,test_step
- などを定義するのがメインかと思います.
こちらの記事(PyTorch Lightning の API を勉強しよう by @ground0state)が非常に参考になると思います.
2.1.1 Callback について (B1)
- Callbackを親クラスにもつクラスをpl.Trainerのcallbacksとして持つことにより
validationのiteration毎にの動作を記述できます.これとloggerを組み合わせることで,iteration毎の挙動の理解も進めることができます(参考としてearly_stoppingのcodeのリンクを貼っておきます(こちら))..- Callbackのコードは,こちらで確認できますが,Trainerからの呼び出しに関しては(B)の章で確認してみることにします.
2.1.1 pl.Trainer について (B2)
コードはこちらです.一部抜粋して動作を大雑把に確認します. 親クラスに
-
TrainerCallbackHookMixin
- I, -
TrainerOptimizersMixin
- II,- No document, but I think this is about optimizer and just a summary.
-
TrainerDataLoadingMixin
- III
を持ちます.
2.1.1. __init__のデコレータについて (ソース)
- __init__には @_defaults_from_env_varsというデコレーターがついていることが確認できます.このデコレータのソースはこちらです.これにはwrapsというものが使用されていますが,こちらの記事が参考になると思います(Pythonのデコレータにはwrapsをつけるべきという覚え書き by @moonwalkerpoday(株式会社日本システム技研)).このデーコレータ自体のは,環境変数の値を変数として追加してくれる様です.
2.1.2. __init__の引数について (ソース)
-
全ての引数に初期値が与えられていることがわかります.全て追うのは大変なので,私的に注意したい値をpickoutしていきます.
-
logger
- 実験追跡のためのロガー(またはロガーの反復可能なコレクション).
- Trueの場合,デフォルトのTensorBoardLoggerが使用される.Falseにすると、ロギングを無効にします.
- 複数のロガーが提供され,そのロガーの save_dir プロパティが設定されていない場合,ローカルファイル(チェックポイント、プロファイラートレースなど)は,個々のロガーのlog_dir ではなく default_root_dir に保存されます。
-
default_root_dir
- logger/ckpt_callback が渡されない場合の、ログと重みのデフォルトパス.
- デフォルト:
os.getcwd()
. s3://mybucket/path` や 'hdfs://path/' のようなリモートファイルのパスを指定することができる.
-
callbacks
- コールバック コールバックまたはコールバックのリストを追加します.
- checkpoint_callback。もし
True
ならば、チェックポイントを有効にする. - ... 非推奨:: v1.5
-
checkpoint_callback
は v1.5 で非推奨となり、v1.7 で削除される予定である。代わりにenable_checkpointing
を使用することを検討してください.
-
-
enable_checkpointing
-
True
の場合、チェックポイントを有効にします - paramref:
~pytorch_lightning.trainer.Trainer.callbacks
にユーザ定義の ModelCheckpoint が存在しない場合、デフォルトの ModelCheckpoint コールバックが設定されます.
-
-
progress_bar_refresh_rate
- プログレスバーを更新する頻度(単位:ステップ).値
0
はプログレスバーを無効にする. paramref:~Trainer.callbacks
にカスタムのプログレスバーが渡された場合は、無視されます. - デフォルト:None, 環境(ターミナル、Google COLABなど)に応じて適切な値が選択されます.
- ... deprecated:: v1.5
progress_bar_refresh_rate
は v1.5 で廃止され、 v1.7 で削除される予定です。- 代わりに, :class:
~pytorch_lightning.callbacks.progress.TQDMProgressBar
にrefresh_rate
をつけて,Trainer のcallbacks
引数に直接渡してください.プログレスバーを無効にするには,enable_progress_bar = False
を Trainer に渡します.
- 代わりに, :class:
- プログレスバーを更新する頻度(単位:ステップ).値
-
enable_progress_bar
- Whether to enable to progress bar by default.
2.1.3. trainer.fit について (ソース)
- trainer.fit関数自体は以下のコードです(UnionやOptionalにある見慣れない型はpl/utilities/types.pyで定義されています).
def fit(
self,
model: "pl.LightningModule",
train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None,
val_dataloaders: Optional[EVAL_DATALOADERS] = None,
datamodule: Optional[LightningDataModule] = None,
ckpt_path: Optional[str] = None,
) -> None:
self.training_type_plugin.model = model
self._call_and_handle_interrupt(
self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
)
- この中の
training_type_plugin
はこちらで定義されています.
@property
def training_type_plugin(self) -> TrainingTypePlugin:
return self._accelerator_connector.training_type_plugin
-
そして,その中の
_accelerator_connector
で定義されています.そしてそれはAcceleratorConnector
を元にしていて,training_type_plugin
が定義されています.また,これは,final_training_type_plugin
で定義されています...(以下略). -
コードが複雑で細かい部分は追えませんでしたが,公式doc(TrainingTypePlugin)と同じ様なものと考えて良さそうで,モデルの初期設定をした様なものなんだと思います(ここは完全に憶測です.時間がある時に正しく追いたいと思います).
-
そして,
_call_and_handle_interrupt
ではerrorの処理をしつつ(細かいところはパスします),第一引数のCallable(関数)を実装します. -
そのため,(
def fit
より)self._fit_impl
に注目すれば挙動の理解を進めることができます.
2.1.4. self._fit_impl について (ソース)
self.state.fn = TrainerFn.FITTING
self.state.status = TrainerStatus.RUNNING
self.training = True
-
まず,
TrainerFn
を使って,trainerの状態を記述しているのがわかります. -
そして,次に
_data_conector
でデータをtrainerにリンクをする様です.- 具体的には,DataConectorのattach_dataの内容です.
-
そして,
_run
を実装します.
流れとしては
fr"""
Lightning internal flow looks like this:
{Trainer.fit} or {Trainer.test} or {Trainer.predict} ||
| ||
spawn processes ||
{self.training_type_plugin.setup_environment} ||
| ||
setup accelerator ||
and strategy || LIGHTNING
| ||
{self.run_stage} || FLOW
| ||
{self._run_train} || DIRECTION
or {self._run_evaluate} ||
or {self._run_predict} ||
| ||
results \/
This is used to guide readers to the core loops: train, test, predict.
{self._run_predict} is the simplest to understand, use `Go to Definition` to read it :)
"""
-
まず,環境に関するsetupを行います.
-
次にここで
run_stage
の実装を行います.これが肝で,_run_train
が実装されます.-
_run_train
の中でも今回はfit_loop
に注目します. -
これは,
FitLoop
を元に作られています. - また,
TrainingEpochLoop
やTrainingBatchLoop
,EvaluationLoop
が連結されて定義されています.- それぞれのLoop関数を見ると
trainer._call_callback_hooks
を用いて"on_train_batch_end"
や"on_validation_end"
が呼ばれているのがわかります. - 呼び出され方の詳細はこちらです(コード).
- また,
fit_loop.run
の定義はFitLoopの親クラスのLoopクラスのこちらです.
- それぞれのLoop関数を見ると
-
-
今回の章からcallbacksによって定義されたクラスの中の
on_validation_end
の呼ばれ方や学習の様子をどの様にコーディングするかを完璧とは言いませんが,確認することができました. -
次の章ではwandbloggerおよびloggerの使い方や設定方法について見ていきたいと思います.
3. WandbLoggerについて(C) (ソース)
- 親クラスは
LightningLoggerBase
です.- 細かくは追いませんでしたが,trainerからexperimentが呼び出されることが予測されます.
- 基本的にwandb.initと同じ挙動を行うことができます(_wandb_initとそれの呼び出しより)
- 具体的な実装は,このページのdemoが参考になると思います.
3.1. __init__の引数について (ソース)
name: Display name for the run.
save_dir: Path where data is saved (wandb dir by default).
offline: Run offline (data can be streamed later to wandb servers).
id: Sets the version, mainly used to resume a previous run.
version: Same as id.
anonymous: Enables or explicitly disables anonymous logging.
project: The name of the project to which this run will belong.
log_model: Log checkpoints created by :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint`
as W&B artifacts. `latest` and `best` aliases are automatically set.
* if ``log_model == 'all'``, checkpoints are logged during training.
* if ``log_model == True``, checkpoints are logged at the end of training, except when
:paramref:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint.save_top_k` ``== -1``
which also logs every checkpoint during training.
* if ``log_model == False`` (default), no checkpoint is logged.
prefix: A string to put at the beginning of metric keys.
experiment: WandB experiment object. Automatically set when creating a run.
\**kwargs: Arguments passed to :func:`wandb.init` like `entity`, `group`, `tags`, etc.