LoginSignup
2
5

More than 5 years have passed since last update.

Chainerの中身を探ってみよう ~独自の拡張機能を追加するには~

Last updated at Posted at 2017-11-01

はじめに

※この記事は,「大規模ソフトウェア実験」の報告書の一部として作成しますが,Pythonライブラリ(特にChainer)の探り方がどなたにでもわかるように紹介できたらと思います。

今回実験で扱う題材は,機械学習フレームワークとして知られるChainerです。
Chainerにどう手を加えるかは自由ですが,ここでは学習ループの間に任意の処理を追加することを考えます。
例えば,エポック毎にトレーニングロス等の具合をグラフで表示してくれたり,現在のウェイトを可視化してくれたりする機能があったら面白そうやん。

さて,以下では,次のことを説明していきます。
- Chainerのインストール方法
- trainerとextensionについて
- 機能の追加方法

Chainerを直接インストールする

機械学習フレームワークとして知られるChainerは,Pythonのライブラリとして提供されています。
なので,通常用いる場合はpipでインストールしてすぐに使うことができます。

$ pip install chainer

しかしここでは,中身を触れるようにしたいので,githubからソースコードをcloneし,直接インストールします。(要するに,pipを使わずにインストールします。)
では作業スペースをworkspaceとし,以下のようにChainerをcloneしてきます。

$ mkdir workspace
$ cd workspace
$ git clone https://github.com/chainer/chainer
$ ls
chainer
$

次に,chainer/にあるsetup.pyを実行します。

$ cd chainer
$ python setup.py install

ここで,

$ python
>>> import chainer
>>>

chainerが無事インポートできればOKです。
また,cloneしてきたフォルダにはMNISTを学習するのサンプルプログラムが含まれているので,それを実行して確認してみるのもよいでしょう。(場所はworkspace/chainer/examples/mnist/train_mnist.py)

$ python workspace/chainer/examples/mnist/train_mnist.py
GPU: -1
# unit: 1000
# Minibatch-size: 100
# epoch: 20

epoch       main/loss   validation/main/loss  main/accuracy  validation/main/accuracy  elapsed_time
1           0.192617    0.100486              0.9426         0.9685                    26.3924

さてこれから,workspace/chainerの中身を変えていきます。中身を変えたら,再びこのようにインストールしてみて下さい。その変更が反映されるようになります。

Chainerの中身を覗いてみよう

Chainerの中身は,まさにworkspace/chainer/chainer以下がそうなので,自由にソースコードを見てみてください。
例えば,activation.pyやoptimizer.pyなど,ファイル名で内容が容易に想像できるものがたくさんあります。

今回は,学習ループに着目するので,そこを担うtrainer.pyを見てみます。
先のMNISTサンプルプログラムでもわかるように,trainer.pyはメインの学習ループを担っています。
trainingをインポートし,trainer.Trainerを生成,.run()で学習ループを実行します。

workspace/chainer/examples/mnist/train_mnist.py
10 from chainer import training

...

77 trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

...

119    # Run the training
120    trainer.run()

trainer.pyの場所は,workspace/chainer/chainer/training/trainer.pyにあります。
269行目にrun()関数があるのを確認できます。

workspace/chainer/chainer/training/trainer.py(L269)
269    def run(self, show_loop_exception_msg=True):
270        """Executes the training loop.
271
272        This method is the core of ``Trainer``. It executes the whole loop of
273        training the models.
274
275        Note that this method cannot run multiple times for one trainer object.
276
277        """

さていよいよこれに変更を加えてみましょう。
さらに見ていくと,305行目にメインの学習ループがあります。

workspace/chainer/chainer/training/trainer.py(L305)
305        # main training loop
306        try:
307            while not stop_trigger(self):
308                self.observation = {}
309                with reporter.scope(self.observation):
310                    update()
311                    for name, entry in extensions:
312                        if entry.trigger(self):
313                            entry.extension(self)

試しにここにprint文を加えてみましょう。

workspace/chainer/chainer/training/trainer.py(L305)
305        # main training loop
306        try:
307            while not stop_trigger(self):
308                self.observation = {}
309                print('aaaa') # here
310                with reporter.scope(self.observation):
311                    update()
312                    for name, entry in extensions:
313                        if entry.trigger(self):
314                            entry.extension(self)

変更を保存したら,再びインストール作業を行います。

$ cd workspace/chainer
$ python setup.py install

もう一度サンプルプログラムを実行してみましょう。

$ python workspace/chainer/examples/mnist/train_mnist.py
GPU: -1
# unit: 1000
# Minibatch-size: 100
# epoch: 20

aaaa
epoch       main/loss   validation/main/loss  main/accuracy  validation/main/accuracy  elapsed_time
aaaa
aaaa

このように,変更が反映されているはずです!

拡張機能の追加はextensionで行う

実は上で見たtrainerには,学習時の拡張機能の追加が簡単にできるような仕組みが予め備わっています。
MNISTサンプルコードの79行目を見てみましょう。

train_mnist.py
11 from chainer.training import extensions

...

79    # Evaluate the model with the test dataset for each epoch
80    trainer.extend(extensions.Evaluator(test_iter, model, device=args.gpu))
81
82    # Dump a computational graph from 'loss' variable at the first iteration
83    # The "main" refers to the target link of the "main" optimizer.
84    trainer.extend(extensions.dump_graph('main/loss'))
85
86    # Take a snapshot for each specified epoch
87    frequency = args.epoch if args.frequency == -1 else max(1, args.frequency)
88    trainer.extend(extensions.snapshot(), trigger=(frequency, 'epoch'))
89
90    # Write a log of evaluation statistics for each epoch
91    trainer.extend(extensions.LogReport())

ここが,trainer.extendという文言で拡張機能を追加している部分になります。
上で挙げた79~91行目の部分では,4つの拡張機能を追加していることになります。
それらは,
- Evaluator
- dump_graph
- snapshot
- LogReport
です。
これらはデフォルトで用意されている拡張機能であり,workspace/chainer/chainer/training/extensions/以下にあるので確認してみて下さい。
例えばevaluator.pyを見てみましょう。

workspace/chainer/chainer/training/extensions/evaluator.py
14 class Evaluator(extension.Extension):
15
16     """Trainer extension to evaluate models on a validation set.

コメント文から,モデルの評価を行う拡張機能であることがわかります。この拡張機能を使えば,エポック毎にモデルの評価が可能になります。

また,Evaluatorクラスはextension.Extensionクラスを継承していることがわかります。
これも覗いてみましょう。

workspace/chainer/chainer/training/extension.py
9 class Extension(object):
10 
11     """Base class of trainer extensions.

Extensionクラスが,拡張機能に共通な要素を用意してくれているようです。

したがって,独自の拡張機能を作りたい場合は,このExtensionクラスを継承して独自のクラスを作成すればよいということがわかります。作成したものは,デフォルトの拡張機能たち(evaluator.py等)と同じ場所に置いておくのがよさそうです。
そしてそれを使用するには,サンプルと同様にextendを呼び出してそこに渡せばOKだと思います。

また,extension.pyには,与えられた関数を拡張機能の形にして返してくれるmake_extension関数があることもわかりました。

workspace/chainer/chainer/training/extension.py
111 def make_extension(trigger=None, default_name=None, priority=None,
112                   finalizer=None, initializer=None, **kwargs):
113     """Decorator to make given functions into trainer extensions.

これを用いれば,もっと簡単に独自の拡張機能を追加することができそうですね。
関数を定義して,それをこいつに渡せばよいわけです。

まとめ

以上では,Chainerのインストール方法,中身の解説・探り方,学習時の拡張機能の追加方法について説明してきました。
その結果,Chainerには拡張機能が追加しやすいような仕組みが予め備わっていることがわかりました。
追加したい機能を単にソースにそのまま書き込むのではなく,この仕組を用いたほうが拡張機能としてキレイに追加できると思います。
それは,
1. Extensionクラスを継承して新たなクラスとして追加実装する
2. make_extensionを使って関数を渡すことで追加する
です。
試してみて下さい。

2
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
5