15
19

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

chainer 1.11.0以降のmnistを解説

Last updated at Posted at 2016-10-13

#はじめに
chainerが1.11.0になってから結構変わっていたので、自分なりの理解を書きます。
できるだけ、初めてpythonとchainerをやってみる人にも分かるようします(つもりです)。

コードはここ
サンプルの中のtrain_mnist.pyというファイルです。

#MNIST
mnistとは28x28のサイズの数字が書かれた画像のデータセットです。
機械学習で入門用としてよく使われるものです。

#ネットワーク

class MLP(chainer.Chain):
    def __init__(self, n_in, n_units, n_out):
        super(MLP, self).__init__(
            l1=L.Linear(n_in, n_units),
            l2=L.Linear(n_units, n_units), 
            l3=L.Linear(n_units, n_out), 
        )
 
    def __call__(self, x):
        h1 = F.relu(self.l1(x))
        h2 = F.relu(self.l2(h1))
        return self.l3(h2)

ネットワークの定義では__init__の方で、使う層を定義する。
今回は、

  • 入力が_n_in_、出力が_n_units_の全結合層 l1
  • 入力が_n_units_、出力が_n_units_の全結合層 l2
  • 入力が_n_units_、出力が_n_outs_の全結合層 l3

__call__では、具体的なネットワークを記述する。
今回は、l1とl2の出力に_relu_という活性化関数を使っている。

#Parser
一応parserのことも

    parser = argparse.ArgumentParser(description='Chainer example: MNIST')
    parser.add_argument('--batchsize', '-b', type=int, default=100,
                help='Number of images in each mini-batch')
    parser.add_argument('--epoch', '-e', type=int, default=20,
                help='Number of sweeps over the dataset to train')
    parser.add_argument('--gpu', '-g', type=int, default=-1,
                help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--out', '-o', default='result',
                help='Directory to output the result')
    parser.add_argument('--resume', '-r', default='',
                help='Resume the training from snapshot')
    parser.add_argument('--unit', '-u', type=int, default=1000,
                help='Number of units')
    args = parser.parse_args()
 
    print('GPU: {}'.format(args.gpu))
    print('# unit: {}'.format(args.unit))
    print('# Minibatch-size: {}'.format(args.batchsize))
    print('# epoch: {}'.format(args.epoch))
    print('')

parserはpythonをコマンドで実行するときにパラメーターを設定しやすくしてくれる便利なやつです。
ターミナルで例えば以下のように実行すると

> $ python train_mnist.py -g 0 -u 100
GPU: 0
# unit: 100
# Minibatch-size: 100
# epoch: 20

と表示されます。
指定してないepochなどは、_default_で初期化されている値になります。
自分で追加したいときは

add_argument('後で呼ぶための名前', '-ターミナルでの指定方法', 数字ならtype=int, 指定のなかった場合のdefault値)

のような形で利用できます。

#データの初期化
chainerではtrainデータと、testデータを用意します。

train, test = chainer.datasets.get_mnist()

これはmnistで使われるデータを取ってきてtrainとtestに入れてるだけです。
なかがどんなカタチになっているかというと、一つの行(train[0])に
[[.234809284, .324039284, .34809382 …. .04843098], 3]
というように、左に入力値と右にその答え(ラベル値)がセットで入っています。
また、chainerではtrainで学習して、testで試してみて正解率を見ていく感じになります。

#イテレータ
従来では自分でfor分を用意して何回も回して学習させてとやっていたのですが、1.11.0からは上の train のようにデータを入れておいて、これを使いますと言ってあげればfor分を書く必要はありません。

train_iter = chainer.iterators.SerialIterator(train, args.batchsize)
test_iter = chainer.iterators.SerialIterator(test, args.batchsize,
                                                 repeat=False, shuffle=False)

もうこれでいいらしい。おまじない感ある

#Trainer
trainerというものが追加されてこれがもうほぼほぼ勝手にいろいろやってくれるそう。
問題集と答えを家庭教師に渡して、子供をよろしくお願いします。的な
自分で勉強を教えてたのを、家庭教師に任せるイメージ(合っているかはわからない)

まず、trainerを設定する。

updater = training.StandardUpdater(train_iter, optimizer, device=args.gpu)
trainer = training.Trainer(updater, (args.epoch, 'epoch'),

この train_iter (問題集)を使って、この optimizer (勉強方法)で最適化してもらって、
それを _epoch _ (何周)回してください。

以下については必ずしも必要なわけではないものもある。

trainer.extend(extensions.Evaluator(test_iter, model, device=args.gpu))
    # これはいる。test_iterを使ってepochごとに評価してる(と思う)
trainer.extend(extensions.dump_graph('main/loss'))
    # ネットワークの形をグラフで表示できるようにdot形式で保存する。
trainer.extend(extensions.snapshot(), trigger=(args.epoch, 'epoch'))
    # epochごとのtrainerの情報を保存する。それを読み込んで、途中から再開などができる。これけすと結構早くなったりした?
trainer.extend(extensions.LogReport())
    # epochごとにlogをだす
trainer.extend(extensions.PrintReport(
        ['epoch', 'main/loss', 'validation/main/loss',
         'main/accuracy', 'validation/main/accuracy']))
    # logで出す情報を指定する。
trainer.extend(extensions.ProgressBar())
    # 今全体と、epochごとでどのぐらい進んでいるかを教えてくれる。
 
trainer.run()
    # trainerをいろいろ設定した後、これをやって実際に実行する。これは必須

_main/loss_は答えとの差の大きさ。
_mian/accuracy_は正解率。
validation/main/accuracyが何を指しているかは、よくわかりません。(誰かコメントしていただけると...)

なんでここ説明して、あそこ説明しないのとかになると思うけどそれはまだ良くわかって無いからだったり

実際にどう弄ったかなどは、まだ上げる予定です。

15
19
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
15
19

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?