14
9

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

Chainer/CuPyAdvent Calendar 2019

Day 15

PyTorchでChainerのTrainerを使えるライブラリ作った

Last updated at Posted at 2019-12-15

 今月頭、PFNが研究開発で用いる深層学習フレームワークをChainerからPyTorchに移すことを発表しました。これを機に自分が使うフレームワークもPyTorchに移行させるつもりですが、Chainerの使い勝手が良かったところを置いていくのが名残惜しいです。特にTrainerとか。

 そこで、ChainerのTrainerを、PyTorchに持っていくことにしました! こちらがそのライブラリです。
 pytorch-trainer | PyTorch's Trainer like Chainer's Trainer

pipでインストールできます。

pip install git+https://github.com/Hiroshiba/pytorch-trainer

この記事では、Chainerをforkして作成したこのTrainerライブラリを紹介します。

特徴

 MNISTを学習するサンプルコードを使って特徴を紹介していきます。

ChainerのTrainerとほとんど同じインターフェースで扱える

直感的に使えるようにするため、ChainerのTrainerと同じインターフェースにすることを目指しました。PyTorchのOptimizerがモデルを持てないので直接Trainerに与えること以外は、ほぼ同様に使えると思います。

# Iterator
train_iter = pytorch_trainer.iterators.SerialIterator(train, batchsize)

# Updater
updater = training.updaters.StandardUpdater(
    train_iter, optimizer, model, device=device)

# Trainer
trainer = training.Trainer(updater, (epoch, 'epoch'), out=out)

Reporterも使える

 Reporterも同じインターフェースで使えます。

class Classifier(nn.Module):
    def __init__(self, predictor):
        super(Classifier, self).__init__()
        self.predictor = predictor

    def forward(self, x, t):
        y = self.predictor(x)
        loss = F.nll_loss(y, t)
        reporter.report({'loss': loss}, self)
        acc = accuracy(y, t)
        reporter.report({'accuracy': acc}, self)
        return loss

Extensionも使える

 Evaluatorを含むExtensionを頑張って引き継ぎました。Scheduler関連は、PyTorch側に別のものがあるため移植していません。

# Evaluator
trainer.extend(extensions.Evaluator(test_iter, model, device=device),
               call_before_training=True)

# Snapshot
trainer.extend(extensions.snapshot(n_retains=1, autoload=autoload),
               trigger=(frequency, 'epoch'))

# LogReport
trainer.extend(extensions.LogReport(), call_before_training=True)

# PlotReport
trainer.extend(
    extensions.PlotReport(['main/loss', 'validation/main/loss'],
                          'epoch', file_name='loss.png'),
    call_before_training=True)

# ProgressBar
trainer.extend(extensions.ProgressBar())

保存・読込ができる

 Trainerにあったserialize機能を全てPyTorch用に書き換えました。これにより、Trainerをまるごと保存して、あとから学習を再開することができます。

# Resume
trainer.load_state_dict(torch.load(resume_pth))

 以上です。大体の学習タスクはTrainerで解決できると思います。インストールは以下のコマンドで可能です。

pip install git+https://github.com/Hiroshiba/pytorch-trainer

 PyTorchのDataLoaderが使えなかったり、Schedulerが使えなかったりと若干の不便が残っています(2020/05/24追記 Schedulerを追加しました)。私はPyTorchにあまり詳しくないので、ちょっとわかるって方はぜひ気軽にPullRequestを送ってください!

14
9
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
14
9

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?