概要
今回イチからディープラーニングについて学ぶ事になり、MNISTの学習モデルを作成しました。
ただそれだけだと面白みに欠けるので、HTMLも使って自分で書いた数字を認識させてみたり、その精度を上げる為にデータオーグメンテーション等の試行錯誤を行ってみたので、その過程を記事にしてみました。
MNISTの学習
ここはあまり本題ではないので、ざっくりになります。
Google Colaboratory上で、kerasを利用してニューラルネットワークを構築し、MNISTデータセットの学習をしました。
その際のモデルの構造と各パラメータは以下の通りです。
model.add(Conv2D(16, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1)))
model.add(Conv2D(32, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Dense(64, activation='relu'))
model.add(Dropout(0.5))
model.add(Flatten())
model.add(Dense(10, activation='softmax'))
model.compile(loss='categorical_crossentropy',
optimizer='rmsprop',
metrics=['accuracy'])
history = model.fit(x_train, y_train,
batch_size=128,
epochs=10,
verbose=1,
validation_data=(x_valid, y_valid))
今回は畳み込みニューラルネットワークを構築してみました。このモデルの損失値と精度は以下の通りです。
Loss : 0.0321417897939682
Accuracy : 0.9893749952316284
数値で見るとそれほど悪くはないように見えますね。
このモデルを後ほど使って自分で書いた数字を認識させようと思うので、Google Driveに保存しておきます。
drive.mount('/content/drive')
model.save('/content/drive/My Drive/~~~/model_name.h5')
手書き領域の実装
次はColaboratory上で手書き数字を書けるように実装していくのですが、ここで悩んだ点がいくつか出てきました。
- Colaboratory上でHTMLを動かす方法
- HTML側で受け取った値をpython側に渡す方法
- そもそもの手書き領域の実装どうするか
HTMLを動かす方法
まずHTMLを動かす方法については、IPythonライブラリのdisplayとHTML関数を利用して実装しました。
html = HTML('''
~~
好きなようにHTMLコードを書く
~~
''')
display(html)
これでpythonからHTMLコードを実行することが出来ました。
HTMLからpythonへの値受け渡し
次にHTML側からpythonへの値受け渡しに関してですが、
https://colab.research.google.com/notebooks/snippets/advanced_outputs.ipynb#scrollTo=QS5x4lFf0fJE
こちらを参考に、google.colabのoutput.register_callback()
関数を利用して実装しました。
def predictImage(imageCode):
hogehoge
~~
output.register_callback('PredictImage', predictImage)
python上で定義した関数をcallbackとして登録しておいて、
document.getElementById("confirmButton").addEventListener("click", () => {
google.colab.kernel.invokeFunction('PredictImage', [canvas.toDataURL()], {})
});
例えばボタンが押された時などに呼び出す様な構造にしておけば、HTML側からpython側のコードを呼び出すことが出来ます。この時引数に値をセット出来るので、動的な値の受け渡しが出来るという形です。
手書き領域の実装
最後に手書き領域の実装に関してです。
普通にイチから書くと手間がかかりそうだったので、Fabric.jsというライブラリを利用することにしました。
http://fabricjs.com/
<canvas id="canvas" height="140" width="140" style="border-style: solid; border-color: black;"></canvas>
<div>
<button type="button" id="confirmButton" style="margin-top:20px">GO!!</button>
</div>
<!-- Fabric.jsの読み込み -->
<script src="https://cdnjs.cloudflare.com/ajax/libs/fabric.js/4.5.0/fabric.min.js"></script>
<script>
const canvas = new fabric.Canvas("canvas", {isDrawingMode: true, backgroundColor: 'black'});
canvas.freeDrawingBrush.width = 5;
canvas.freeDrawingBrush.color = 'white';
document.getElementById("confirmButton").addEventListener("click", () => {
google.colab.kernel.invokeFunction('PredictImage', [canvas.toDataURL()], {})
});
</script>
結果としては、HTMLとjs部分はこれだけに収まりました。Fabric.jsを読み込んで、canvasを用意して、後はjsの中で背景色や文字色を設定するだけです。
canvasに描かれたデータは、canvas.toDataURL()
でpythonに渡す様にしました。
モデルに手書き数字を予測させる
さて、ここからはpython側の実装です。
まずはcanvasに描かれたデータがどの様な形式になっているのかを確認してみます。
 ……
base64形式でエンコードされたデータになっていました。
最初のdata:image/png;base64,
が余計なのでここを削除してからデコードします。
decodedImage = base64.b64decode(re.sub('data:image/png;base64,', '', imageCode))
b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x8c\x00\x00\x00\x8c\x08\x06\x00\x00\x00\xae\xc0A> ……
これで画像として扱えるデータになりました。
ここから最初に作成したモデルに入力できる形にデータ形式を整えていく為に、PILで読み込んで加工していきます。
# pilで読み込み
pil_image = Image.open(BytesIO(decodedImage))
# グレイスケール
pil_image_gray = pil_image.convert("L")
# リサイズ
resized_image = pil_image_gray.resize((28, 28))
canvasで描かれたデータはRGBAの4チャネルのデータが入っていたのでグレイスケールに変換してから、モデルに合う様に[28, 28]の行列に直します。
最後にnp.arrayに変換して、[1, 28, 28, 1]の形に直したら加工は終わりです。
np_image = np.array(resized_image, dtype=float)
# 正規化
np_image /= 255
今回はモデルを学習させる時に0~1の範囲で正規化をしたので、同じく正規化処理を入れておきます。
一応[28, 28]の形になった手書き数字を見てみましょう。
ではいよいよモデルに予測させてみます!
model = tf.keras.models.load_model('/content/drive/My Drive/~~~/model_name.h5')
result = model.predict(reshaped_np_image, batch_size=1, verbose=0)
prediction = result[0].argmax()
print('予測結果: ', prediction)
わりと予測出来てそうですね!
ただ、しばらく書いて予測させてみると色々問題が見えてきたのでした。
精度を見てみる
書き方や数字によっては、結構認識されない場合がありました。
また、数字を書く位置を真ん中からズラすと全く認識されません。
この辺りの認識精度を改善するために色々試行錯誤してみます。
ニューラルネットワークの改善をしてみる
まずはニューラルネットワークをよりディープなものにしてみます。
model.add(Conv2D(16, kernel_size=(3, 3), padding='same', activation='relu', input_shape=(28, 28, 1)))
model.add(Conv2D(16, (3, 3), padding='same', activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(32, (3, 3), padding='same', activation='relu'))
model.add(Conv2D(32, (3, 3), padding='same', activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(64, (3, 3), padding='same', activation='relu'))
model.add(Conv2D(64, (3, 3), padding='same', activation='relu'))
model.add(Conv2D(64, (3, 3), padding='same', activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(128, (3, 3), padding='same', activation='relu'))
model.add(Conv2D(128, (3, 3), padding='same', activation='relu'))
model.add(Conv2D(128, (3, 3), padding='same', activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.25))
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(10, activation='softmax'))
model.compile(loss='categorical_crossentropy',
optimizer='rmsprop',
metrics=['accuracy'])
history = model.fit(x_train, y_train,
batch_size=128,
epochs=10,
verbose=1,
validation_data=(x_valid, y_valid))
VGG16を参考にしたモデルを構築して、認識精度がどれくらい改善するかを見てみます。
チャネル数が少ない様に思えますが、これ以上増やすと逆に精度が下がってしまうという現象が起こったので、MNISTの様な単純なデータセットにはこれくらいで良いのかもしれません。
多少改善されたようには思えますが、やはり中央からズラした位置に書くと全く認識されません。
データオーグメンテーションをしてみる
データオーグメンテーションで教師データを加工して、ズレた位置にある数字を学習するようにしてみます。
今回はkerasのImageDataGenerator
を使って実装してみました。
https://keras.io/ja/preprocessing/image/
idg = image.ImageDataGenerator(height_shift_range=0.3, width_shift_range=0.3)
generator = idg.flow(x_train, y_train, batch_size=128)
ImageDataGeneratorの引数でどうデータを加工するかのパラメータを渡します。
今回は数字の位置をズラしたいので、ランダムで縦横方向に元画像の最大0.3倍ずつズラす様な設定をしました。
試しに加工された画像を見てみるとこんな感じです。
では、モデルを学習させます。構造は先ほどのVGG16モデルの形式のままです。
history = model.fit(generator,
batch_size=128,
epochs=10,
verbose=1,
validation_data=(x_valid, y_valid))
fit_generator
という関数もあるみたいですが、こちらは現在非推奨となっている様なので、そのままfit関数にgeneratorを入れて学習しました。
精度を見てみます。
数字を端に寄せて書いても認識してくれる様になりました!
最初に作ったモデルと比較するとかなり性能が改善されたのが分かります。
さいごに
今回作成したMNIST学習、データオーグメンテーション、手書き数字予測までを含んだプログラムをNotebookに公開しておきます。
モデルの学習に1時間とかかかるのでそこはご了承ください。Google Driveにモデルを保存する場合は、コード内のコメントアウト部分を活性化させてください。
https://colab.research.google.com/drive/1WsaPia0xHyQ31WHIfMNa5rqz7-TScGaI?usp=sharing