2
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

PyTorch-Lightningとそのwandblogger について(メモ)

Posted at

1.はじめに

この記事は主にpytorch_lightningのcallbackや学習がどの様に呼び出されているのかに注力を当ててコードを追って見ました.また,wandbloggerというものがどの様に実装されているのかを確認しました.しかし,この記事はメモ書きのようにとっ散らかっているため,参考程度にしていただけると幸いです.

1.0 結論

  • wandbloggerもwandbも同じ挙動を再現できる.
  • Loopというクラスで定義されている部分が肝でpytorch_lightningの学習は進んでいく.
  • 完全に理解するには,なかなかハードであると感じた.

1.1. wandbについて

1.2. pytorch-lightningの実装について

  • pytorch-lightningの実装に関しては,直接wandbを用いるのではなく,WandbLoggerを使います(下がpseudo[/ˈsjuːdəʊ/]-codeです).
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import Callback
import pytorch_lightning as pl
class MyModule(pl.LightningModule): # (A)
   ...
class MyCallback(Callback): # (B1)
   def on_validation_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
      ...
wandb_logger=WandbLogger(**kwargs) # (C)
trainer=pl.Trainer(logger=wandb_logger,callbacks=Mycallback(),**kwargs)
model=MyModule() # (B1)
wandb_logger.watch(model, log="all") #(C)
trainer.fit(model, train_loader, val_loader) # (B2)
model = MyModule.load_from_checkpoint(
       trainer.checkpoint_callback.best_model_path
) # (B)
  • wandbを直接使用できない理由は,

    • trainerの中でmodelを学習をするためで,wandbを直接こちらに組み込むには手間がかかるからです.
  • また,pytorch-lightning(pl)を用いることで,

    • モデルの学習過程も一括で管理できる利点があるため,wandbを使いつつplを使いたいと考えています.

また,plを使うためこちらに関しても一部まとめたいと思います.

2. pytorch-lightningの全体について

  • まずは,wandbloggerに関係ないplの基本的な部分をおさらいします.

2.1. pl.LightningModule (A)

こちらのクラスを継承することで学習や検証に関することを一括で定義できます.

  • forward, configure_optimizers,training_step,validation_step,test_step
  • などを定義するのがメインかと思います.

こちらの記事(PyTorch Lightning の API を勉強しよう by @ground0state)が非常に参考になると思います.

2.1.1 Callback について (B1)

  • Callbackを親クラスにもつクラスをpl.Trainerのcallbacksとして持つことにより
    validationのiteration毎にの動作を記述できます.これとloggerを組み合わせることで,iteration毎の挙動の理解も進めることができます(参考としてearly_stoppingのcodeのリンクを貼っておきます(こちら))..
    • Callbackのコードは,こちらで確認できますが,Trainerからの呼び出しに関しては(B)の章で確認してみることにします.

2.1.1 pl.Trainer について (B2)

コードはこちらです.一部抜粋して動作を大雑把に確認します. 親クラスに

を持ちます.

2.1.1. __init__のデコレータについて (ソース)

2.1.2. __init__の引数について (ソース)

  • 全ての引数に初期値が与えられていることがわかります.全て追うのは大変なので,私的に注意したい値をpickoutしていきます.

  • logger

    • 実験追跡のためのロガー(またはロガーの反復可能なコレクション).
    • Trueの場合,デフォルトのTensorBoardLoggerが使用される.Falseにすると、ロギングを無効にします.
    • 複数のロガーが提供され,そのロガーの save_dir プロパティが設定されていない場合,ローカルファイル(チェックポイント、プロファイラートレースなど)は,個々のロガーのlog_dir ではなく default_root_dir に保存されます
  • default_root_dir

    • logger/ckpt_callback が渡されない場合の、ログと重みのデフォルトパス.
    • デフォルト: os.getcwd(). s3://mybucket/path` や 'hdfs://path/' のようなリモートファイルのパスを指定することができる.
  • callbacks

    • コールバック コールバックまたはコールバックのリストを追加します.
    • checkpoint_callback。もし True ならば、チェックポイントを有効にする.
    • ... 非推奨:: v1.5
      • checkpoint_callback は v1.5 で非推奨となり、v1.7 で削除される予定である。代わりに enable_checkpointing を使用することを検討してください.
  • enable_checkpointing

    • True の場合、チェックポイントを有効にします
    • paramref:~pytorch_lightning.trainer.Trainer.callbacks にユーザ定義の ModelCheckpoint が存在しない場合、デフォルトの ModelCheckpoint コールバックが設定されます.
  • progress_bar_refresh_rate

    • プログレスバーを更新する頻度(単位:ステップ).値 0 はプログレスバーを無効にする. paramref:~Trainer.callbacksにカスタムのプログレスバーが渡された場合は、無視されます.
    • デフォルト:None, 環境(ターミナル、Google COLABなど)に応じて適切な値が選択されます.
    • ... deprecated:: v1.5 progress_bar_refresh_rate は v1.5 で廃止され、 v1.7 で削除される予定です。
      • 代わりに, :class:~pytorch_lightning.callbacks.progress.TQDMProgressBarrefresh_rate をつけて,Trainer の callbacks 引数に直接渡してください.プログレスバーを無効にするには,enable_progress_bar = False を Trainer に渡します.
  • enable_progress_bar

    • Whether to enable to progress bar by default.

2.1.3. trainer.fit について (ソース)

  • trainer.fit関数自体は以下のコードです(UnionやOptionalにある見慣れない型はpl/utilities/types.pyで定義されています).
def fit(
        self,
        model: "pl.LightningModule",
        train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None,
        val_dataloaders: Optional[EVAL_DATALOADERS] = None,
        datamodule: Optional[LightningDataModule] = None,
        ckpt_path: Optional[str] = None,
    ) -> None:
        self.training_type_plugin.model = model
        self._call_and_handle_interrupt(
            self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
        )
  • この中のtraining_type_pluginこちらで定義されています.
    @property
    def training_type_plugin(self) -> TrainingTypePlugin:
        return self._accelerator_connector.training_type_plugin
  • そして,その中の_accelerator_connectorで定義されています.そしてそれはAcceleratorConnectorを元にしていて,training_type_pluginが定義されています.また,これは,final_training_type_pluginで定義されています...(以下略).

  • コードが複雑で細かい部分は追えませんでしたが,公式doc(TrainingTypePlugin)と同じ様なものと考えて良さそうで,モデルの初期設定をした様なものなんだと思います(ここは完全に憶測です.時間がある時に正しく追いたいと思います).

  • そして,_call_and_handle_interruptではerrorの処理をしつつ(細かいところはパスします),第一引数のCallable(関数)を実装します.

  • そのため,(def fitより)self._fit_implに注目すれば挙動の理解を進めることができます.

2.1.4. self._fit_impl について (ソース)

self.state.fn = TrainerFn.FITTING
self.state.status = TrainerStatus.RUNNING
self.training = True
  • まず, TrainerFnを使って,trainerの状態を記述しているのがわかります.

  • そして,次に_data_conectorでデータをtrainerにリンクをする様です.

  • そして,_runを実装します.
    流れとしては

fr"""
             Lightning internal flow looks like this:
        {Trainer.fit} or {Trainer.test} or {Trainer.predict}  ||
                                |                             ||
                         spawn processes                      ||
           {self.training_type_plugin.setup_environment}      ||
                                |                             ||
                        setup accelerator                     ||
                           and strategy                       ||  LIGHTNING
                                |                             ||
                         {self.run_stage}                     ||  FLOW
                                |                             ||
                        {self._run_train}                     ||  DIRECTION
                     or {self._run_evaluate}                  ||
                     or {self._run_predict}                   ||
                                |                             ||
                             results                          \/
        This is used to guide readers to the core loops: train, test, predict.
        {self._run_predict} is the simplest to understand, use `Go to Definition` to read it :)
"""
  • まず,環境に関するsetupを行います.

  • 次にここrun_stageの実装を行います.これが肝で,_run_trainが実装されます.

  • 今回の章からcallbacksによって定義されたクラスの中のon_validation_endの呼ばれ方や学習の様子をどの様にコーディングするかを完璧とは言いませんが,確認することができました.

  • 次の章ではwandbloggerおよびloggerの使い方や設定方法について見ていきたいと思います.

3. WandbLoggerについて(C) (ソース)

3.1. __init__の引数について (ソース)

name: Display name for the run.
save_dir: Path where data is saved (wandb dir by default).
offline: Run offline (data can be streamed later to wandb servers).
id: Sets the version, mainly used to resume a previous run.
version: Same as id.
anonymous: Enables or explicitly disables anonymous logging.
project: The name of the project to which this run will belong.
log_model: Log checkpoints created by :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint`
    as W&B artifacts. `latest` and `best` aliases are automatically set.
    * if ``log_model == 'all'``, checkpoints are logged during training.
    * if ``log_model == True``, checkpoints are logged at the end of training, except when
        :paramref:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint.save_top_k` ``== -1``
        which also logs every checkpoint during training.
    * if ``log_model == False`` (default), no checkpoint is logged.
prefix: A string to put at the beginning of metric keys.
experiment: WandB experiment object. Automatically set when creating a run.
\**kwargs: Arguments passed to :func:`wandb.init` like `entity`, `group`, `tags`, etc.
2
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?