LoginSignup
768
771

More than 5 years have passed since last update.

手書きひらがなの認識で99.78%の精度をディープラーニングで

Posted at

手書きひらがなの認識、教科書通りの畳み込みニューラルネットワーク(ディープラーニング)で、99.78%の精度が出ました。教科書通りである事が(独自性がない事が)逆に読む方・書く方にメリットがありそうなので、Qiitaで記事にします。

ソースコード

ソースコードは https://github.com/yukoba/CnnJapaneseCharacter です。

いきさつ

いきさつは、友人と手書きひらがなの認識の雑談をFacebookでしていて、ググったら、この2つが見つかりました。

どちらも2016年3月に書かれた物ですが、スタンフォード大学の学生さんの方が先かな。

スタンフォード大学の学生さんのレポートでは、

  • ひらがな:96.50%
  • カタカナ:98.19%
  • 漢字:99.64%

と、なっていました。

別の友人の分析によると、漢字の方が手がかりが多いから簡単なんじゃないかという予想でした。この記事では、一番精度の悪い、ひらがなを99.78%に上げたという話をします。

データセット

データは皆さん、産総研の「ETL文字データベース」の ETL8G を使っていて、僕もそれを使ってます。具体的な手書き文字を観たい方は http://etlcdb.db.aist.go.jp/?page_id=2461 からどうぞ。

データは 128x127 pxで、4ビットのグレースケール画像で、160人分です。ひらがなは72種。

畳み込みニューラルネットワーク

畳み込みニューラルネットワーク(ディープラーニング)とはなんぞや、という基本的な話は tawago さんの http://qiita.com/tawago/items/931bea2ff6d56e32d693 をご覧ください。あと、O'Reilly Japan の書籍「ゼロから作るDeep Learning - Pythonで学ぶディープラーニングの理論と実装」も入門書として良い感じでした(立ち読みしかしてません、ごめんなさい)。

ライブラリ

ニューラルネットワークのライブラリとしては、今回は、

を使いました。低レイヤーとしては TensorFlow でも動くようにコードを書きました。

プログラミング言語は Python 3 です。

改善点

で、tawago さんの 95.04% やスタンフォード大学の学生さんの 96.50% から何を変えたかという話。本当に基礎的な事しかやってないです。

反復回数

まずスタンフォード大学の学生さん、CPUでやったらしく、計算回数不足だったので、Amazon EC2 の GPU を使いました。エポック数(反復回数)を40回から400回に増やしました。

機械学習は訓練データと評価データに分けます。訓練データに対して学習させるのですが、確率的勾配降下法は、乱数を使いますし、原理的に、細かく増えたり減ったりガタガタに推移するのですが、その話とは別の話として、学習結果を評価データに対して適用した時に、ある所まで改善し、ある所から悪くなる事が良くあるのですが、これを過学習といいます。

反復回数(エポック数)は理想としては過学習が始まる所までやるべきで(早期終了)、今回は300~400回目くらいで過学習が始まるようなので(真面目に確認してないです)、400回にしてあります。

訓練データと評価データの分割数は8:2です。これはスタンフォード大学の学生さんに従いました。

モデル

モデルは俗に「VGG風」と呼ばれるやつを使いました。2014年9月に Very Deep Convolutional Networks for Large-Scale Image Recognition としてオックスフォード大学の人が発表した物です。これを元にして書かれた Keras のサンプルが https://github.com/fchollet/keras/blob/master/examples/cifar10_cnn.py で、僕はこれを改造しています。スタンフォード大学の学生さんも、VGG風です。tawago さんは不明。VGG風なのかな?

VGG風というやつは「畳み込み → 畳み込み → Maxプーリング」を繰り返して最後に普通のニューラルネットワークです。

畳み込みなどを軽く説明しますと、こうなります。詳細は書籍「ゼロから作るDeep Learning」をご覧ください。

  • 畳み込み:各点の近傍(3x3など)をとって、1次元ベクトルに変換してパラメータと内積
  • Maxプーリング:各点の近傍(2x2など)をとって、その中の最大値

ETL8G が160人分のデータしかなく、大きなデータセットではないです。一般的にデータが少ない時はパラメータ数の多い複雑なモデルを使っても上手く行かないので、僕は Keras のサンプルそのままの単純なやつを使いました。「畳み込み → 畳み込み → Maxプーリング」は2回です。

ノイズ関数

汎化能力を高める方法として、訓練時のみノイズを加えるという方法があります。僕が今回使ったのは、この2つです。

  • Dropout(確率pで1/p倍、確率1-pで0倍して消す)
  • 教師画像に回転(±15度)、ズーム(0.8~1.2倍)

Dropoutの方はスタンフォード大学の学生さんも使ってます。僕は全てp=0.5、つまり50%の確率で2倍、50%の確率で0にする、です。

入力画像に回転・ズームは、回転・ズームしても文字としては同じだよ、ということを学習させるためで、Keras のサンプルコードでも使ってますし、僕も使いました。これも効果は大きいです。スタンフォード大学の学生さんは使ってませんでした。

残りの細かい話

  • 画像は32x32に縮小しました。これだけあれば十分ですし、大きいと計算量が増えるので。スタンフォード大学の学生さんは64x64にしてました。あと、畳み込みの回数の関係のバランスで、ピクセル数が多い事が改善につながるというわけではないです。
  • Keras のデフォルトだと、初期値がおかしく、学習が進まないので、標準偏差 0.1 の正規分布にしました。
  • 確率的勾配降下法は Adam のようなピョンピョン跳ねるやつは上手く行かないので、最初、穏やかな AdaGrad にしていたのですが、学習率を入れなくても良い AdaGrad の変形の AdaDelta の方が良かったのでそっちを使いました。

結論

というわけで、教科書通りの事やって、99.78%になりました。手書き数字は MNIST のデータセットでは99.77%と報告されていて http://yann.lecun.com/exdb/mnist/ ほぼ同じです。漢字とか他のは僕はやっていませんが、スタンフォード大学の学生さんが99.64%と言っていて、これよりは少しだけ良くなるでしょう。

ディープラーニング、手書き文字は、ほぼ完璧に認識できちゃうんですね!

768
771
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
768
771