動機
前回「Ruby でニューラルネットワーク」では、MNIST のデータを使って、数字認識ができるニューラルネットワークを構築しました。元ネタは、名著「ゼロから作るDeep Learning ――Pythonで学ぶディープラーニングの理論と実装」です。
ただ、数字の認識を行えるとはいっても、訓練用のデータもテスト用のデータも MNIST から与えられたものにすぎません。自分で実際に手書きの数字を描いてみたらどうなるのだろう?という興味から、ウェブ上で、実際に数値を手書きして、そのまま数字認識が試せるようなものを作ってみました。
Digit Recognition
ソースコード(GitHub)
手書き数値の右隣の "3: 99.6%" みたいなところは、たとえば「『3』である確率は 99.6%」ということを意味します。確率の一番高い数字を認識結果としています。
コード解説
neuralnet.js
がメインのコードですが、その中でも実質的にニューラルネットワークの計算を行っているのはこの部分だけです。
function predict(x) {
var a1 = nj.dot(x, net_params.w1).add(net_params.b1);
var z1 = nj.sigmoid(a1);
var a2 = nj.dot(z1, net_params.w2).add(net_params.b2);
var z2 = nj.sigmoid(a2);
var a3 = nj.dot(z2, net_params.w3).add(net_params.b3);
var z3 = nj.softmax(a3);
return z3;
}
3層のニューラルネットワークです。net_params.js
で、重みとバイアスの数値を読み込んで、グローバル変数の net_params
に代入しています。この重みとバイアスの数字自体は、上記の「ゼロから作るDeep Learning ――Pythonで学ぶディープラーニングの理論と実装」から得たものです。
nj.*
で始まる関数は、NumJs というライブラリのものです。NumJs は名前の通り、Python の行列ライブラリ NumPy 風のインターフェイスを持った JavaScript のライブラリです。とても使いやすかったです。おすすめ。
感想
実際の認識精度はというと・・・残念ながらそれほど高くないようです。数字によって明らかに認識率が違います。「3」は、かなりよく認識されるのですが、「4」はほぼ壊滅的です。もともと、MNIST のテストデータを使った正答率は、96% 程度なので、本来はもっとよく認識されてもよさそうなものです。
おそらくは、MNIST の画像と今回の手書きの画像では、なんらかの「画風」の違いがあるのでしょう。もちろん、今回風の画像での学習は行っていませんので、正答率が低めになっているのかもしれません。あるいは、CNNのようなより深層なネットワークを使えば、正答率を上げることができるのかもしれません。今後の研究課題ですね。
それにしても、NumJs 。元々は、math.jsという数値計算ライブラリを使っていたのですが、reshape の処理がうまくできなかったので、NumJs に切り替えました。NumPy に非常に近い使い勝手があって、簡単な行列計算ならなんでもこなせそうです。Ruby にもほしいです…(Ruby の NMatrix のほうが使いにくい印象)。
今回、実際にウェブサイト(というほど大仰なものではありませんが)を作ってみて、HTML5 の Canvas の使い方や、静的コンテンツを Heroku にホストする方法などを理解することができました。JavaScript には苦手意識があったのですが、実際、腰を入れて取り組んでみると、これはこれで結構面白いと思いました。
今後の展望
ニューラルネットワークの仕組みは、実装レベルで理解できたので、これからは、もう少し高レベルに視点を移して、Chainer や TensorFlow と言った、ディープラーニング用のライブラリについて調べてみるつもりです。