Qiita Teams that are logged in
You are not logged in to any team

Log in to Qiita Team
Community
OrganizationAdvent CalendarQiitadon (β)
Service
Qiita JobsQiita ZineQiita Blog
Help us understand the problem. What is going on with this article?

Chainer で手書き数字認識(MNIST)

More than 3 years have passed since last update.

概要

定番の手書き数字データセット MNIST を Chainer を使用して3層のニューラルネットワークで学習してみました。

3層ニューラルネットワーク

Chainer では、MNIST のデータを自動的にダウンロードして使用できる便利なメソッドが用意されています。それを使って、3層ニューラルネットワーク(2つある隠れ層のノードはそれぞれ 50)で学習を行い、正答率(accuracy)を測定しました。下のコードは、Chainer の MNIST サンプル を下敷きにしていますが、大幅に変更を加えてあります。

コード

neural_net.py
import numpy as np
import chainer
from chainer import Chain, Variable
import chainer.functions as F
import chainer.links as L

class NeuralNet(chainer.Chain):
    def __init__(self, n_units, n_out):
        super().__init__(
            l1=L.Linear(None, n_units),
            l2=L.Linear(n_units, n_units),
            l3=L.Linear(n_units, n_out),
        )

    def __call__(self, x):
        h1 = F.relu(self.l1(x))
        h2 = F.relu(self.l2(h1))
        return self.l3(h2)

def check_accuracy(model, xs, ts):
    ys = model(xs)
    loss = F.softmax_cross_entropy(ys, ts)
    ys = np.argmax(ys.data, axis=1)
    cors = (ys == ts)
    num_cors = sum(cors)
    accuracy = num_cors / ts.shape[0]
    return accuracy, loss

def main():
    model = NeuralNet(50, 10)

    optimizer = chainer.optimizers.Adam()
    optimizer.setup(model)

    train, test = chainer.datasets.get_mnist()
    xs, ts = train._datasets
    txs, tts = test._datasets

    bm = 100

    for i in range(100):

        for j in range(600):
            model.zerograds()
            x = xs[(j * bm):((j + 1) * bm)]
            t = ts[(j * bm):((j + 1) * bm)]
            t = Variable(np.array(t, "i"))
            y = model(x)
            loss = F.softmax_cross_entropy(y, t)
            loss.backward()
            optimizer.update()

        accuracy_train, loss_train = check_accuracy(model, xs, ts)
        accuracy_test, _           = check_accuracy(model, txs, tts)

        print("Epoch %d loss(train) = %f, accuracy(train) = %f, accuracy(test) = %f" % (i + 1, loss_train.data, accuracy_train, accuracy_test))

if __name__ == '__main__':
    main()


実行結果

Epoch 1 loss(train) = 0.242552, accuracy(train) = 0.926683, accuracy(test) = 0.924800
Epoch 2 loss(train) = 0.175040, accuracy(train) = 0.946167, accuracy(test) = 0.942000
Epoch 3 loss(train) = 0.133406, accuracy(train) = 0.959050, accuracy(test) = 0.954400
...
Epoch 98 loss(train) = 0.002298, accuracy(train) = 0.999267, accuracy(test) = 0.971300
Epoch 99 loss(train) = 0.002876, accuracy(train) = 0.998917, accuracy(test) = 0.972900
Epoch 100 loss(train) = 0.003336, accuracy(train) = 0.998917, accuracy(test) = 0.973500

解説

1エポック(epoch)は、600回のイタレーション(iteration)で構成されています。100エポック後に、訓練(train)データに対する正答率は、99.9%、テストデータに対する正答率は97.3%となっています。Optimizer は SGD も試しましたが、Adam のほうがはるかに学習速度が速かったです。上記のコードの他に、各層のノードの数を少し増やしたり、層を1つ増やして4層にしたりもしてみましたが、正答率に大きな違いはありませんでした。

感想

「ゼロから作る Deep Learning」を読んで5ヶ月が経ちましたが、ようやく Chainer で実装するところまでたどり着きました。当たり前ですが、やはりフレームワークを作ると非常にすっきりと記述できますね。今回のコードを記述するにあたり、Chainer の Link, Chain, Optimizer クラスあたりのソースコードを読み込みました。とても簡潔にかかれており、読めばちゃんと理解できるのには感動しました。徐々に Chainer がわかってきた感があります。

参考文献

Chainer 本家ドキュメント

elm200
ソフトウェアエンジニア。Python と機械学習。
http://elm200.hatenablog.com/
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away