12
Help us understand the problem. What are the problem?

More than 1 year has passed since last update.

posted at

PyTorchのライブラリIgniteすげえ便利(GW第2弾)

はじめに

5月入って2日目です。
本日は、Igniteというライブラリを触っていました。これはNNの学習のためのライブラリで、LightiningCatalystと同じようなものです。前にLightningは少し触っていましたがあまり自分の中でヒットはせず、、、。ですが、Igniteは結構使えそうな気がします!興味を持った方は是非使ってみてください!

Igniteとは

IgniteはPyTorchのEcosystem Toolsにあり、ニューラルネットの学習のための高レベルライブラリです。Igniteを使用する利点は学習のコードをコンパクトに書けるところです。

$ pip install pytorch-ignite

でインストール可能です。

IgniteでMNIST

Igniteには、チュートリアルとしてMNISTのデータセットを使用したものがありました。今回は、そのコードを一部変更しています。

Igniteのポイント

以下のポイントが抑えられれば使いこなすことは簡単です!

  1. create_supervised_trainer() : trainerの定義

  2. create_supervised_evaluator() : evaluatorの定義

  3. @trainer.on([実行するタイミング]) : 各処理を定義した関数をこのデコレートする

  4. add_event_handler() : モデル保存やEarlyStoppingなどの設定

  5. tensorboardlogger : ログの吐き出し口

1.trainerの定義

trainer = create_supervised_trainer(model, optimizer, criterion, device=device)

trainerには、model, optimizer, criteriondeviceを渡します。

2. evaluatorの定義

metrics = {
            "accuracy" : Accuracy(), 
            "loss" : Loss(criterion)
            }

# evaluator
train_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)
validation_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)

evaluatorには、model, metrics, deviceを渡します。

metricsには以下のようなものがあります。

  • Accuracy
  • Average
  • ConfusionMatrix
  • IoU()
  • mIoU()
  • Loss
  • Recall
  • Precision
  • RunningAverage などです。

3. @trainer.on([実行するタイミング])

ある関数を書いて@trainer.on([実行するタイミング])でデコレートします。エポックの終了時にはacuracyとlossを表示させたり、自分が好きなタイミングでどのような処理をするかを記述します。例をみた方が早いと思うので以下に例をあげます。

@trainer.on(Events.EPOCH_COMPLETED)
    def log_training_result(engine):
        train_evaluator.run(train_loader)
        metrics = train_evaluator.state.metrics
        print("Training Results - Epoch[{}] Avg accuracy: {:.2f} Avg loss: {:.2f}"
          .format(engine.state.epoch, metrics['accuracy'], metrics['loss']))

@trainer.on(Events.EPOCH_COMPLETED)は、ある1エポックが終了した際に実行されるという意味です。今回だと、1エポック終了時にaccuracyとlossを表示しています。

train_evaluator.run(train_loader)でevaluatorを実行させます。

実行するタイミングは、

  • COMPLETED
  • EPOCH_COMPLETED
  • EPOCH_STARTED
  • ITERATION_COMPLETED
  • ITERATION_STARTED

などあります。

stateは、デフォルトで以下の属性を持っています。

  • state.iteration
  • state.epoch
  • state.seed
  • state.dataloader
  • state.epoch_length
  • state.batch
  • state.output
  • state.metrics

4. add_event_handler()

イベントハンドラーを追加したいときに使います。

add_event_handler(event_name, handler, *args, **kwargs)

以下は、モデルを保存するときの例です。


handler = ModelCheckpoint('/tmp/models', 
                          'myprefix', 
                          n_saved=2, 
                          create_dir=True
                          )

trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), handler, {'mymodel': model})

今回は使っていませんので、詳しくは公式のドキュメントで確認してください。

5. TensorBoardLogger

TensorBoardは、trainやvalidation時のログやモデル、パラメータ、勾配を表示してくれます。

XXX.attach(trainer, log_handler, event_name)を指定します。
OutputHandlerのtagは、tensorboardで確認したときのタグになっています。

tb_logger = TensorboardLogger(log_dir=log_dir)


tb_logger.attach(trainer, 
                 log_handler=OutputHandler(tag="training", 
                                           output_transform=lambda loss:{"batchloss":loss},
                                           metric_names="all"),
                 event_name = Events.ITERATION_COMPLETED(every=100),)


tb_logger.attach(train_evaluator,
                 log_handler=OutputHandler(tag="training",
                                           metric_names=["loss", "accuracy"],
                                           another_engine=trainer),
                 event_name = Events.EPOCH_COMPLETED, )


tb_logger.attach(validation_evaluator,
                 log_handler=OutputHandler(tag="validation",
                                           metric_names=["loss", "accuracy"],
                                           another_engine=trainer),
                 event_name = Events.EPOCH_COMPLETED,)


OutputHandler以外にもいろんなlog_handlerがあるので気になる方は、チェックしてみてください。

6. コード全体

import sys
from argparse import ArgumentParser
import logging

import torch
import torchvision
from torch.utils.data import DataLoader
from torch import nn
import torch.nn.functional as F
from torch.optim import SGD
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Normalize

from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss



from ignite.contrib.handlers.tensorboard_logger import *



LOG_INTERVAL = 10


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=-1)



def get_data_loaders(train_batch_size, val_batch_size):
    data_transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])

    train_loader = DataLoader(MNIST(download=True, root='.',transform=data_transform, train=True), 
                   batch_size=train_batch_size, shuffle=True)

    val_loader = DataLoader(MNIST(download=True, root='.',transform=data_transform, train=False), 
                   batch_size=val_batch_size, shuffle=False)

    return train_loader, val_loader



def train(train_batch_size, val_batch_size, epochs, lr, momentum, log_dir):

    train_loader, val_loader = get_data_loaders(train_batch_size, val_batch_size)

    # define the model, device, optimizer, loss
    model = Net()
    device = "cpu"
    model.to(device)
    optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
    criterion = nn.CrossEntropyLoss()


    # define the trainer and evaluator engine
    # trainer
    trainer = create_supervised_trainer(model, optimizer, criterion, device=device)


    # define two metrics : accuracy and loss to compute on val dataset
    metrics = {
            "accuracy" : Accuracy(), 
            "loss" : Loss(criterion)
            }

    # evaluator
    train_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)
    validation_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)



    # When an epoch ends we want compute training and val metrics
    # attach two additional handlers to the trainer on epoch compute event
    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_result(engine):
        train_evaluator.run(train_loader)
        metrics = train_evaluator.state.metrics
        print("Training Results - Epoch[{}] Avg accuracy: {:.2f} Avg loss: {:.2f}"
          .format(engine.state.epoch, metrics['accuracy'], metrics['loss']))


    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_result(engine):
        validation_evaluator.run(val_loader)
        metrics = validation_evaluator.state.metrics
        print("Validation Results - Epoch[{}] Avg accuracy: {:.2f} Avg loss: {:.2f}"
          .format(engine.state.epoch, metrics['accuracy'], metrics['loss']))

    tb_logger = TensorboardLogger(log_dir=log_dir)


    tb_logger.attach(trainer, 
                     log_handler=OutputHandler(tag="training", 
                                               output_transform=lambda loss:{"batchloss":loss},
                                               metric_names="all"),
                     event_name = Events.ITERATION_COMPLETED(every=100),
                     )

    tb_logger.attach(train_evaluator,
                     log_handler=OutputHandler(tag="training",
                                               metric_names=["loss", "accuracy"],
                                               another_engine=trainer),
                     event_name = Events.EPOCH_COMPLETED,
                     )

    tb_logger.attach(validation_evaluator,
                     log_handler=OutputHandler(tag="validation",
                                               metric_names=["loss", "accuracy"],
                                               another_engine=trainer),
                     event_name = Events.EPOCH_COMPLETED,
                     )



    tb_logger.attach(
            trainer,
            log_handler=OptimizerParamsHandler(optimizer),
            event_name = Events.ITERATION_COMPLETED(every=100)
            )

    tb_logger.attach(trainer, 
                     log_handler=WeightsScalarHandler(model),
                     event_name=Events.EPOCH_COMPLETED(every=100)
                     )

    tb_logger.attach(trainer, 
                     log_handler=WeightsHistHandler(model),
                     event_name=Events.EPOCH_COMPLETED(every=100)
                     )

    tb_logger.attach(trainer,
                     log_handler=GradsScalarHandler(model),
                     event_name=Events.ITERATION_COMPLETED(every=100))


    tb_logger.attach(trainer,
                     log_handler=GradsHistHandler(model),
                     event_name=Events.ITERATION_COMPLETED(every=100))

    # kick everything off

    trainer.run(train_loader, max_epochs=epochs)
    tb_logger.close()


if __name__ == "__main__":

    parser = ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--val_batch_size", type=int, default=1000)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--lr", type=float, default=0.01)
    parser.add_argument("--momentum", type=float, default=0.5)
    parser.add_argument("--log_dir", type=str, default="tensorboard_logs")

    args = parser.parse_args()

    # setup engine logger

    # logger作成
    logger = logging.getLogger("ignite.engine.engine.Engine")
    handler = logging.StreamHandler()

    formatter = logging.Formatter("%(asctime)s %(name)-12s %(levelname)-8s %(message)s")
    handler.setFormatter(formatter)
    # ロガーに追加
    logger.addHandler(handler)
    # ログレベルの設定
    logger.setLevel(logging.INFO)

    train(args.batch_size, args.val_batch_size, args.epochs, args.lr, args.momentum, args.log_dir)

TensorBoardのログを確認

今回のプログラムは、tensorboardでログを確認できるようになっています。

tensorboard --logdir='./tensorboard_logs'

以下のようなものが確認可能です。

Screen Shot 2020-05-02 at 12.56.06.png

Screen Shot 2020-05-02 at 12.56.29.png

終わりに

簡単にではありますが、Igniteについて見てみました。学習周りがかなりスッキリしたように思います。簡単な検証をまわすときや使える場面は多いかと思うので使っていきたいと思います。

参考文献

Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Sign upLogin
12
Help us understand the problem. What are the problem?