はじめに
5月入って2日目です。
本日は、Igniteというライブラリを触っていました。これはNNの学習のためのライブラリで、LightiningやCatalystと同じようなものです。前にLightningは少し触っていましたがあまり自分の中でヒットはせず、、、。ですが、Igniteは結構使えそうな気がします!興味を持った方は是非使ってみてください!
#Igniteとは
IgniteはPyTorchのEcosystem Toolsにあり、ニューラルネットの学習のための高レベルライブラリです。Igniteを使用する利点は学習のコードをコンパクトに書けるところです。
$ pip install pytorch-ignite
でインストール可能です。
#IgniteでMNIST
Igniteには、チュートリアルとしてMNISTのデータセットを使用したものがありました。今回は、そのコードを一部変更しています。
###Igniteのポイント
以下のポイントが抑えられれば使いこなすことは簡単です!
-
create_supervised_trainer()
: trainerの定義 -
create_supervised_evaluator()
: evaluatorの定義 -
@trainer.on([実行するタイミング])
: 各処理を定義した関数をこのデコレートする -
add_event_handler()
: モデル保存やEarlyStoppingなどの設定 -
tensorboardlogger : ログの吐き出し口
###1.trainerの定義
trainer = create_supervised_trainer(model, optimizer, criterion, device=device)
trainerには、model
, optimizer
, criterion
とdevice
を渡します。
###2. evaluatorの定義
metrics = {
"accuracy" : Accuracy(),
"loss" : Loss(criterion)
}
# evaluator
train_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)
validation_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)
evaluatorには、model
, metrics
, device
を渡します。
metricsには以下のようなものがあります。
- Accuracy
- Average
- ConfusionMatrix
- IoU()
- mIoU()
- Loss
- Recall
- Precision
- RunningAverage
などです。
###3. @trainer.on([実行するタイミング])
ある関数を書いて@trainer.on([実行するタイミング])
でデコレートします。エポックの終了時にはacuracyとlossを表示させたり、自分が好きなタイミングでどのような処理をするかを記述します。例をみた方が早いと思うので以下に例をあげます。
@trainer.on(Events.EPOCH_COMPLETED)
def log_training_result(engine):
train_evaluator.run(train_loader)
metrics = train_evaluator.state.metrics
print("Training Results - Epoch[{}] Avg accuracy: {:.2f} Avg loss: {:.2f}"
.format(engine.state.epoch, metrics['accuracy'], metrics['loss']))
**@trainer.on(Events.EPOCH_COMPLETED)**は、ある1エポックが終了した際に実行されるという意味です。今回だと、1エポック終了時にaccuracyとlossを表示しています。
train_evaluator.run(train_loader)
でevaluatorを実行させます。
実行するタイミングは、
- COMPLETED
- EPOCH_COMPLETED
- EPOCH_STARTED
- ITERATION_COMPLETED
- ITERATION_STARTED
などあります。
stateは、デフォルトで以下の属性を持っています。
- state.iteration
- state.epoch
- state.seed
- state.dataloader
- state.epoch_length
- state.batch
- state.output
- state.metrics
#4. add_event_handler()
イベントハンドラーを追加したいときに使います。
add_event_handler(event_name, handler, *args, **kwargs)
以下は、モデルを保存するときの例です。
handler = ModelCheckpoint('/tmp/models',
'myprefix',
n_saved=2,
create_dir=True
)
trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), handler, {'mymodel': model})
今回は使っていませんので、詳しくは公式のドキュメントで確認してください。
#5. TensorBoardLogger
TensorBoardは、trainやvalidation時のログやモデル、パラメータ、勾配を表示してくれます。
XXX.attach(trainer, log_handler, event_name)
を指定します。
OutputHandlerのtagは、tensorboardで確認したときのタグになっています。
tb_logger = TensorboardLogger(log_dir=log_dir)
tb_logger.attach(trainer,
log_handler=OutputHandler(tag="training",
output_transform=lambda loss:{"batchloss":loss},
metric_names="all"),
event_name = Events.ITERATION_COMPLETED(every=100),)
tb_logger.attach(train_evaluator,
log_handler=OutputHandler(tag="training",
metric_names=["loss", "accuracy"],
another_engine=trainer),
event_name = Events.EPOCH_COMPLETED, )
tb_logger.attach(validation_evaluator,
log_handler=OutputHandler(tag="validation",
metric_names=["loss", "accuracy"],
another_engine=trainer),
event_name = Events.EPOCH_COMPLETED,)
OutputHandler以外にもいろんなlog_handlerがあるので気になる方は、チェックしてみてください。
#6. コード全体
import sys
from argparse import ArgumentParser
import logging
import torch
import torchvision
from torch.utils.data import DataLoader
from torch import nn
import torch.nn.functional as F
from torch.optim import SGD
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Normalize
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss
from ignite.contrib.handlers.tensorboard_logger import *
LOG_INTERVAL = 10
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
self.conv2_drop = nn.Dropout2d()
self.fc1 = nn.Linear(320, 50)
self.fc2 = nn.Linear(50, 10)
def forward(self, x):
x = F.relu(F.max_pool2d(self.conv1(x), 2))
x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
x = x.view(-1, 320)
x = F.relu(self.fc1(x))
x = F.dropout(x, training=self.training)
x = self.fc2(x)
return F.log_softmax(x, dim=-1)
def get_data_loaders(train_batch_size, val_batch_size):
data_transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])
train_loader = DataLoader(MNIST(download=True, root='.',transform=data_transform, train=True),
batch_size=train_batch_size, shuffle=True)
val_loader = DataLoader(MNIST(download=True, root='.',transform=data_transform, train=False),
batch_size=val_batch_size, shuffle=False)
return train_loader, val_loader
def train(train_batch_size, val_batch_size, epochs, lr, momentum, log_dir):
train_loader, val_loader = get_data_loaders(train_batch_size, val_batch_size)
# define the model, device, optimizer, loss
model = Net()
device = "cpu"
model.to(device)
optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
criterion = nn.CrossEntropyLoss()
# define the trainer and evaluator engine
# trainer
trainer = create_supervised_trainer(model, optimizer, criterion, device=device)
# define two metrics : accuracy and loss to compute on val dataset
metrics = {
"accuracy" : Accuracy(),
"loss" : Loss(criterion)
}
# evaluator
train_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)
validation_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)
# When an epoch ends we want compute training and val metrics
# attach two additional handlers to the trainer on epoch compute event
@trainer.on(Events.EPOCH_COMPLETED)
def log_training_result(engine):
train_evaluator.run(train_loader)
metrics = train_evaluator.state.metrics
print("Training Results - Epoch[{}] Avg accuracy: {:.2f} Avg loss: {:.2f}"
.format(engine.state.epoch, metrics['accuracy'], metrics['loss']))
@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_result(engine):
validation_evaluator.run(val_loader)
metrics = validation_evaluator.state.metrics
print("Validation Results - Epoch[{}] Avg accuracy: {:.2f} Avg loss: {:.2f}"
.format(engine.state.epoch, metrics['accuracy'], metrics['loss']))
tb_logger = TensorboardLogger(log_dir=log_dir)
tb_logger.attach(trainer,
log_handler=OutputHandler(tag="training",
output_transform=lambda loss:{"batchloss":loss},
metric_names="all"),
event_name = Events.ITERATION_COMPLETED(every=100),
)
tb_logger.attach(train_evaluator,
log_handler=OutputHandler(tag="training",
metric_names=["loss", "accuracy"],
another_engine=trainer),
event_name = Events.EPOCH_COMPLETED,
)
tb_logger.attach(validation_evaluator,
log_handler=OutputHandler(tag="validation",
metric_names=["loss", "accuracy"],
another_engine=trainer),
event_name = Events.EPOCH_COMPLETED,
)
tb_logger.attach(
trainer,
log_handler=OptimizerParamsHandler(optimizer),
event_name = Events.ITERATION_COMPLETED(every=100)
)
tb_logger.attach(trainer,
log_handler=WeightsScalarHandler(model),
event_name=Events.EPOCH_COMPLETED(every=100)
)
tb_logger.attach(trainer,
log_handler=WeightsHistHandler(model),
event_name=Events.EPOCH_COMPLETED(every=100)
)
tb_logger.attach(trainer,
log_handler=GradsScalarHandler(model),
event_name=Events.ITERATION_COMPLETED(every=100))
tb_logger.attach(trainer,
log_handler=GradsHistHandler(model),
event_name=Events.ITERATION_COMPLETED(every=100))
# kick everything off
trainer.run(train_loader, max_epochs=epochs)
tb_logger.close()
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--batch_size", type=int, default=64)
parser.add_argument("--val_batch_size", type=int, default=1000)
parser.add_argument("--epochs", type=int, default=10)
parser.add_argument("--lr", type=float, default=0.01)
parser.add_argument("--momentum", type=float, default=0.5)
parser.add_argument("--log_dir", type=str, default="tensorboard_logs")
args = parser.parse_args()
# setup engine logger
# logger作成
logger = logging.getLogger("ignite.engine.engine.Engine")
handler = logging.StreamHandler()
formatter = logging.Formatter("%(asctime)s %(name)-12s %(levelname)-8s %(message)s")
handler.setFormatter(formatter)
# ロガーに追加
logger.addHandler(handler)
# ログレベルの設定
logger.setLevel(logging.INFO)
train(args.batch_size, args.val_batch_size, args.epochs, args.lr, args.momentum, args.log_dir)
TensorBoardのログを確認
今回のプログラムは、tensorboardでログを確認できるようになっています。
tensorboard --logdir='./tensorboard_logs'
以下のようなものが確認可能です。
#終わりに
簡単にではありますが、Igniteについて見てみました。学習周りがかなりスッキリしたように思います。簡単な検証をまわすときや使える場面は多いかと思うので使っていきたいと思います。
#参考文献