一層のニューラルネットワーク(?)でMNISTの分類をしたいと思います。
MNISTデータの前処理
MNISTのデータは,バイナリ形式なので,これを読み取るための種類を作ります。バイナリのフォーマットは,MNISTのウェブページの一番下に書かれています。
プロデルでバイナリデータを読み込む上で注意しないといけないのは,バイナリ関連の手順がすべて保護機能になっている点です。セキュリティ関連の歴史的経緯から,開発者登録をしないとこの機能が使えません。登録はこのページから一瞬でできます。
まず,読み取ったバイトデータを32ビット整数に変換する手順を作ります。
【バイト:配列】を整数化する手順
【出力:整数】
バイトの個数回,【カウンタ】にカウントして繰り返す
出力は,出力+バイト(カウンタ)×256^(バイトの個数-カウンタ)
繰り返し終わり
出力を返す。
終わり
ビッグエンディアンであると書かれているので,このような処理になりました。
次に,この手順を使ってデータを読み込んでいきます。ファイルの最初の4バイトがマジックナンバーとなっており,2049の場合はラベルデータ,2051の場合は画像データです。
ラベルのデータは,one-hotベクトルにする必要があるので,そのようにしています。
MNIST読取器とは
【MNIST】は,バイナリファイルを作ったもの。
はじめ(ファイル)の手順
MNISTへファイルを読み取り専用で開く。
マジックナンバーは,{[MNISTから読み取ったもの],[MNISTから読み取ったもの],[MNISTから読み取ったもの],[MNISTから読み取ったもの]}を整数化したもの。
【データ数】は,{[MNISTから読み取ったもの],[MNISTから読み取ったもの],[MNISTから読み取ったもの],[MNISTから読み取ったもの]}を整数化したもの。
もしマジックナンバーが2049なら ’ラベル
データは,行列(データ数,10)を作ったもの。
データ数回,【現在行】にカウントして繰り返す
【添字】は,MNISTから読み取ったもの+1。
【一行】は,{0,0,0,0,0,0,0,0,0,0}。
一行(添字)は,1。
データの中身(現在行)は,一行。
繰り返し終わり
他でもしマジックナンバーが2051なら ’画像
【画像サイズ】は,({[MNISTから読み取ったもの],[MNISTから読み取ったもの],[MNISTから読み取ったもの],[MNISTから読み取ったもの]}を整数化したもの)×({[MNISTから読み取ったもの],[MNISTから読み取ったもの],[MNISTから読み取ったもの],[MNISTから読み取ったもの]}を整数化したもの)。
データは,行列(データ数,画像サイズ)を作ったもの。
データ数回,【現在行】にカウントして繰り返す
【画像一つ】は,{}。
画像サイズ回,【カウンタ】にカウントして繰り返す
画像一つ(カウンタ)は,MNISTから読み取ったもの。
繰り返し終わり
データの中身(現在行)は,画像一つ。
繰り返し終わり
もし終わり
終わり
終わり
これでMNISTのデータを読み込めるようになりました。
学習をする
学習データが60000個もあるので,インタプリタだとデータの読み込みだけで10分以上かかります(当然コンピュータの性能によります)。従って,コンパイルしてから実行することをオススメします。
バッチサイズは,100。
反復数は,60000。
学習率は,0.02。
(反復数÷バッチサイズ)回,【カウンタ】にカウントして繰り返す
ミニバッチデータの中身は,訓練データのデータの中身の((カウンタ-1)×バッチサイズ+1)番目からバッチサイズ個切り出したもの。
ミニバッチラベルの中身は,訓練ラベルのデータの中身の((カウンタ-1)×バッチサイズ+1)番目からバッチサイズ個切り出したもの。
勾配は,ネットワークでミニバッチラベルが正解でミニバッチデータにおける勾配を計算したもの。
勾配の個数回,【キー】にカウントして繰り返す
ネットワークのパラメータ(キー)から,学習率を勾配(キー)に全部かけたものを破壊的に引く。
繰り返し終わり
「[カウンタ]/[反復数÷バッチサイズ]」を表示。
ネットワークでミニバッチラベルが正解でミニバッチデータにおける損失を計算したものを表示。
繰り返し終わり
今回はデータが多いので,100個ずつのミニバッチを作っています。
ちなみに,損失の計算にもそこそこ時間がかかるので,表示しないという選択もあるかもしれません。
学習には1時間半くらいかかりました(Xeon E3-1505M v5 2.80GHz)。
推定をする
評価用データも10000個あるので,結構かかります。
推定結果は,[ネットワークで[評価データのデータ]から推定したもの]が最大値となる添字を転置したもの。
評価ラベルのデータが最大値となる添字を転置したもので推定結果を比較したものを表示。
出力
途中省略しています。
PS C:\*****> .\MNIST_NN.exe
評価用ラベルの読み込み完了
評価用データの読み込み完了
訓練ラベルの読み込み完了
訓練データの読み込み完了
1/600
18.7430426599258
2/600
17.2233298105334
3/600
13.6773555012828
4/600
13.6773554545378
5/600
15.1970616161542
(略)
595/600
3.54598104376164
596/600
2.27955924241463
597/600
4.55911848483926
598/600
3.54598104376164
599/600
-1.00000008273537E-11
600/600
1.01313744106761
0.8311
ということで,正解率83.1%でした。意外と線形でも正解率出るものですね。MNISTのウェブページには,線形分類器のエラーレートが12%と書かれているので,ぼちぼちではないかと思います。