17
17

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

Chainer で手書き数字認識(MNIST)

Last updated at Posted at 2017-04-01

概要

定番の手書き数字データセット MNIST を Chainer を使用して3層のニューラルネットワークで学習してみました。

3層ニューラルネットワーク

Chainer では、MNIST のデータを自動的にダウンロードして使用できる便利なメソッドが用意されています。それを使って、3層ニューラルネットワーク(2つある隠れ層のノードはそれぞれ 50)で学習を行い、正答率(accuracy)を測定しました。下のコードは、Chainer の MNIST サンプル を下敷きにしていますが、大幅に変更を加えてあります。

コード

neural_net.py
import numpy as np
import chainer
from chainer import Chain, Variable
import chainer.functions as F
import chainer.links as L

class NeuralNet(chainer.Chain):
    def __init__(self, n_units, n_out):
        super().__init__(
            l1=L.Linear(None, n_units),
            l2=L.Linear(n_units, n_units),
            l3=L.Linear(n_units, n_out),
        )

    def __call__(self, x):
        h1 = F.relu(self.l1(x))
        h2 = F.relu(self.l2(h1))
        return self.l3(h2)

def check_accuracy(model, xs, ts):
    ys = model(xs)
    loss = F.softmax_cross_entropy(ys, ts)
    ys = np.argmax(ys.data, axis=1)
    cors = (ys == ts)
    num_cors = sum(cors)
    accuracy = num_cors / ts.shape[0]
    return accuracy, loss

def main():
    model = NeuralNet(50, 10)

    optimizer = chainer.optimizers.Adam()
    optimizer.setup(model)

    train, test = chainer.datasets.get_mnist()
    xs, ts = train._datasets
    txs, tts = test._datasets

    bm = 100

    for i in range(100):

        for j in range(600):
            model.zerograds()
            x = xs[(j * bm):((j + 1) * bm)]
            t = ts[(j * bm):((j + 1) * bm)]
            t = Variable(np.array(t, "i"))
            y = model(x)
            loss = F.softmax_cross_entropy(y, t)
            loss.backward()
            optimizer.update()

        accuracy_train, loss_train = check_accuracy(model, xs, ts)
        accuracy_test, _           = check_accuracy(model, txs, tts)

        print("Epoch %d loss(train) = %f, accuracy(train) = %f, accuracy(test) = %f" % (i + 1, loss_train.data, accuracy_train, accuracy_test))

if __name__ == '__main__':
    main()


実行結果

Epoch 1 loss(train) = 0.242552, accuracy(train) = 0.926683, accuracy(test) = 0.924800
Epoch 2 loss(train) = 0.175040, accuracy(train) = 0.946167, accuracy(test) = 0.942000
Epoch 3 loss(train) = 0.133406, accuracy(train) = 0.959050, accuracy(test) = 0.954400
...
Epoch 98 loss(train) = 0.002298, accuracy(train) = 0.999267, accuracy(test) = 0.971300
Epoch 99 loss(train) = 0.002876, accuracy(train) = 0.998917, accuracy(test) = 0.972900
Epoch 100 loss(train) = 0.003336, accuracy(train) = 0.998917, accuracy(test) = 0.973500

解説

1エポック(epoch)は、600回のイタレーション(iteration)で構成されています。100エポック後に、訓練(train)データに対する正答率は、99.9%、テストデータに対する正答率は97.3%となっています。Optimizer は SGD も試しましたが、Adam のほうがはるかに学習速度が速かったです。上記のコードの他に、各層のノードの数を少し増やしたり、層を1つ増やして4層にしたりもしてみましたが、正答率に大きな違いはありませんでした。

感想

「ゼロから作る Deep Learning」を読んで5ヶ月が経ちましたが、ようやく Chainer で実装するところまでたどり着きました。当たり前ですが、やはりフレームワークを作ると非常にすっきりと記述できますね。今回のコードを記述するにあたり、Chainer の Link, Chain, Optimizer クラスあたりのソースコードを読み込みました。とても簡潔にかかれており、読めばちゃんと理解できるのには感動しました。徐々に Chainer がわかってきた感があります。

参考文献

Chainer 本家ドキュメント

17
17
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
17
17

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?