今月頭、PFNが研究開発で用いる深層学習フレームワークをChainerからPyTorchに移すことを発表しました。これを機に自分が使うフレームワークもPyTorchに移行させるつもりですが、Chainerの使い勝手が良かったところを置いていくのが名残惜しいです。特にTrainerとか。
そこで、ChainerのTrainerを、PyTorchに持っていくことにしました! こちらがそのライブラリです。
pytorch-trainer | PyTorch's Trainer like Chainer's Trainer
pipでインストールできます。
pip install git+https://github.com/Hiroshiba/pytorch-trainer
この記事では、Chainerをforkして作成したこのTrainerライブラリを紹介します。
特徴
MNISTを学習するサンプルコードを使って特徴を紹介していきます。
ChainerのTrainerとほとんど同じインターフェースで扱える
直感的に使えるようにするため、ChainerのTrainerと同じインターフェースにすることを目指しました。PyTorchのOptimizerがモデルを持てないので直接Trainerに与えること以外は、ほぼ同様に使えると思います。
# Iterator
train_iter = pytorch_trainer.iterators.SerialIterator(train, batchsize)
# Updater
updater = training.updaters.StandardUpdater(
train_iter, optimizer, model, device=device)
# Trainer
trainer = training.Trainer(updater, (epoch, 'epoch'), out=out)
Reporterも使える
Reporterも同じインターフェースで使えます。
class Classifier(nn.Module):
def __init__(self, predictor):
super(Classifier, self).__init__()
self.predictor = predictor
def forward(self, x, t):
y = self.predictor(x)
loss = F.nll_loss(y, t)
reporter.report({'loss': loss}, self)
acc = accuracy(y, t)
reporter.report({'accuracy': acc}, self)
return loss
Extensionも使える
Evaluatorを含むExtensionを頑張って引き継ぎました。Scheduler関連は、PyTorch側に別のものがあるため移植していません。
# Evaluator
trainer.extend(extensions.Evaluator(test_iter, model, device=device),
call_before_training=True)
# Snapshot
trainer.extend(extensions.snapshot(n_retains=1, autoload=autoload),
trigger=(frequency, 'epoch'))
# LogReport
trainer.extend(extensions.LogReport(), call_before_training=True)
# PlotReport
trainer.extend(
extensions.PlotReport(['main/loss', 'validation/main/loss'],
'epoch', file_name='loss.png'),
call_before_training=True)
# ProgressBar
trainer.extend(extensions.ProgressBar())
保存・読込ができる
Trainerにあったserialize機能を全てPyTorch用に書き換えました。これにより、Trainerをまるごと保存して、あとから学習を再開することができます。
# Resume
trainer.load_state_dict(torch.load(resume_pth))
以上です。大体の学習タスクはTrainerで解決できると思います。インストールは以下のコマンドで可能です。
pip install git+https://github.com/Hiroshiba/pytorch-trainer
PyTorchのDataLoaderが使えなかったり、Schedulerが使えなかったりと若干の不便が残っています(2020/05/24追記 Schedulerを追加しました)。私はPyTorchにあまり詳しくないので、ちょっとわかるって方はぜひ気軽にPullRequestを送ってください!