9
Help us understand the problem. What are the problem?

More than 1 year has passed since last update.

posted at

updated at

Organization

PyTorchでChainerのTrainerを使えるライブラリ作った

 今月頭、PFNが研究開発で用いる深層学習フレームワークをChainerからPyTorchに移すことを発表しました。これを機に自分が使うフレームワークもPyTorchに移行させるつもりですが、Chainerの使い勝手が良かったところを置いていくのが名残惜しいです。特にTrainerとか。

 そこで、ChainerのTrainerを、PyTorchに持っていくことにしました! こちらがそのライブラリです。
 pytorch-trainer | PyTorch's Trainer like Chainer's Trainer

pipでインストールできます。

pip install git+https://github.com/Hiroshiba/pytorch-trainer

この記事では、Chainerをforkして作成したこのTrainerライブラリを紹介します。

特徴

 MNISTを学習するサンプルコードを使って特徴を紹介していきます。

ChainerのTrainerとほとんど同じインターフェースで扱える

直感的に使えるようにするため、ChainerのTrainerと同じインターフェースにすることを目指しました。PyTorchのOptimizerがモデルを持てないので直接Trainerに与えること以外は、ほぼ同様に使えると思います。

# Iterator
train_iter = pytorch_trainer.iterators.SerialIterator(train, batchsize)

# Updater
updater = training.updaters.StandardUpdater(
    train_iter, optimizer, model, device=device)

# Trainer
trainer = training.Trainer(updater, (epoch, 'epoch'), out=out)

Reporterも使える

 Reporterも同じインターフェースで使えます。

class Classifier(nn.Module):
    def __init__(self, predictor):
        super(Classifier, self).__init__()
        self.predictor = predictor

    def forward(self, x, t):
        y = self.predictor(x)
        loss = F.nll_loss(y, t)
        reporter.report({'loss': loss}, self)
        acc = accuracy(y, t)
        reporter.report({'accuracy': acc}, self)
        return loss

Extensionも使える

 Evaluatorを含むExtensionを頑張って引き継ぎました。Scheduler関連は、PyTorch側に別のものがあるため移植していません。

# Evaluator
trainer.extend(extensions.Evaluator(test_iter, model, device=device),
               call_before_training=True)

# Snapshot
trainer.extend(extensions.snapshot(n_retains=1, autoload=autoload),
               trigger=(frequency, 'epoch'))

# LogReport
trainer.extend(extensions.LogReport(), call_before_training=True)

# PlotReport
trainer.extend(
    extensions.PlotReport(['main/loss', 'validation/main/loss'],
                          'epoch', file_name='loss.png'),
    call_before_training=True)

# ProgressBar
trainer.extend(extensions.ProgressBar())

保存・読込ができる

 Trainerにあったserialize機能を全てPyTorch用に書き換えました。これにより、Trainerをまるごと保存して、あとから学習を再開することができます。

# Resume
trainer.load_state_dict(torch.load(resume_pth))

 以上です。大体の学習タスクはTrainerで解決できると思います。インストールは以下のコマンドで可能です。

pip install git+https://github.com/Hiroshiba/pytorch-trainer

 PyTorchのDataLoaderが使えなかったり、Schedulerが使えなかったりと若干の不便が残っています(2020/05/24追記 Schedulerを追加しました)。私はPyTorchにあまり詳しくないので、ちょっとわかるって方はぜひ気軽にPullRequestを送ってください!

Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Sign upLogin
9
Help us understand the problem. What are the problem?