メリークリスマス
アドベントカレンダーのネタを探してたら、kaggleで IS THAT SANTA? というデータセットを見つけてしまったので、on-device向けのサンタ識別モデルを作成することにしました。
準備
自分でモデルを組んでも良かったんですが、ここはお手軽にTensorFlow Lite Model Makerを使うことにしました。
TensorFlow Lite Model Makerは転移学習を用いてtensorflow liteのモデルを作成できるライブラリで、少ないデータ量で精度が高いモデルを作成することができます。
また、モデル作成の実行環境は、Colaboratoryを使いました。
TensorFlow Lite Modelのチュートリアルがよくできていて、これで学習するとすぐに使えるようになります。
学習開始
Tensor Lite Model Maker をColoab上でインストールします。
!pip install -q tflite-model-maker
パッケージをインポートします。
import os
import tensorflow as tf
from tflite_model_maker import model_spec
from tflite_model_maker import image_classifier
from tflite_model_maker.config import ExportFormat
from tflite_model_maker.config import QuantizationConfig
from tflite_model_maker.image_classifier import DataLoader
import matplotlib.pyplot as plt
IS THAT SANTA?から、学習用のデータをダウンロードします。
Colab上で、Kaggleから直接DLすることができなかったので、一度KaggleからDLしたzipファイルをCloud Storageに置いてDLするように工夫しました。Colabに直接アップロードするとか、他にも方法あると思います。
image_path = tf.keras.utils.get_file(
'is_that_santa.zip',
'https://example.com/is_that_santa.zip',
extract=True)
train_image_path = os.path.join(os.path.dirname(image_path), 'is that santa/train')
test_image_path = os.path.join(os.path.dirname(image_path), 'is that santa/test')
データの構造ですが、以下のようになってます。
!sudo apt-get install tree
!tree /root/.keras/datasets/is\ that\ santa/
/root/.keras/datasets/is that santa/
├── test
│ ├── not-a-santa
│ │ ├── 0.not-a-santa.jpg
│ │ ├── 100.not-a-santa.jpg
...
│ └── santa
│ ├── 0.Santa.jpg
│ ├── 100.Santa.jpg
...
└── train
├── not-a-santa
│ ├── 101.not-a-santa.jpg
│ ├── 102.not-a-santa.jpg
...
└── santa
├── 101.Santa.jpg
├── 102.Santa.jpg
数が多いので、省略していますが、train
と test
でデータが分かれていて、その中で not-a-santa
と santa
で分けられており、それぞれディレクトリ名がラベル名を表しています。
中の画像データはこんな感じです。
train, testデータそれぞれ、DataLoaderを利用して読み込みます。
train_data = DataLoader.from_folder(train_image_path)
test_data = DataLoader.from_folder(test_image_path)
モデルを学習させます。
model = image_classifier.create(train_data)
ここで問題が発生しました。
InvalidArgumentError: Unknown image file format. One of JPEG, PNG, GIF, BMP required.
画像のファイル形式を以下のコマンドで調べみました。
!ls -d1 /root/.keras/datasets/is\ that\ santa/train/santa/* | xargs -i file "{}" | grep -v "JPEG" | grep -v "PNG"
!ls -d1 /root/.keras/datasets/is\ that\ santa/train/not-a-santa/* | xargs -i file "{}" | grep -v "JPEG" | grep -v "PNG"
!ls -d1 /root/.keras/datasets/is\ that\ santa/test/santa/* | xargs -i file "{}" | grep -v "JPEG" | grep -v "PNG"
!ls -d1 /root/.keras/datasets/is\ that\ santa/test/not-a-santa/* | xargs -i file "{}" | grep -v "JPEG" | grep -v "PNG"
/root/.keras/datasets/is that santa/train/not-a-santa/463.not-a-santa.jpg: Windows desktop.ini, ASCII text, with CRLF line terminators
/root/.keras/datasets/is that santa/test/santa/401.Santa.jpg: RIFF (little-endian) data, Web/P image, VP8 encoding, 863x1300, Scaling: [none]x[none], YUV color, decoders should clamp
/root/.keras/datasets/is that santa/test/not-a-santa/395.not-a-santa.jpg: Windows desktop.ini, ASCII text, with CRLF line terminators
確認すると、なんとjpgの拡張子なのに、全然違うファイル形式のものがありました。
あわてん坊のサンタクロースなのかな?
仕方ないので手で除外し、再度読み込み直します。合計3つありました。
ちなみにjpgの拡張子でPNG形式の画像もありましたが、それは読み込めはするので、そのままにしました。
!rm /root/.keras/datasets/is\ that\ santa/train/not-a-santa/463.not-a-santa.jpg
!rm /root/.keras/datasets/is\ that\ santa/test/santa/401.Santa.jpg
!rm /root/.keras/datasets/is\ that\ santa/test/not-a-santa/395.not-a-santa.jpg
# 再度読み込み、学習させる
train_data = DataLoader.from_folder(train_image_path)
test_data = DataLoader.from_folder(test_image_path)
model = image_classifier.create(train_data)
# summary
model.summary()
Model: "sequential_5"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
hub_keras_layer_v1v2_5 (Hub (None, 1280) 3413024
KerasLayerV1V2)
dropout_5 (Dropout) (None, 1280) 0
dense_5 (Dense) (None, 2) 2562
=================================================================
Total params: 3,415,586
Trainable params: 2,562
Non-trainable params: 3,413,024
_________________________________________________________________
評価
テストデータで評価してみます。
loss, accuracy = model.evaluate(test_data)
20/20 [==============================] - 35s 2s/step - loss: 0.2750 - accuracy: 0.9642
loss: 0.2750
accuracy: 0.9642
なかなかの精度がでました。
以下、画像をみてみます。
not-a-santa
と santa
の2ラベルの分類で、合計が100%になるような確率で推論されるため、高い方の確率のラベルを予測結果として、テストデータの正解と相違があるものを赤文字で表示しています。(つまり間違えてるとする)
Androidに組み込む
Tensorflow Liteのモデルファイルをエクスポートします。
model.export(export_dir='.', tflite_filename='classify_santa_model.tflite')
Android StudioはML Model Bindingという機能があり、tfliteのモデルを簡単に入れることができます。
Tensorflow Liteはモデルの入力、出力の形式などの情報をメタデータとして付与することができます。
AndroidStudioがそのメタデータを読み取ってコードを生成してくれるため、BitmapなどでAndroidアプリ開発者が簡単に実装できるようなります。
Tensorflow Lite Model Makerでモデルを作ると、メタデータも一緒に付与されるようでした。自分でメタデータをつける場合は、ライブラリを駆使してつけないといけないため、これはかなり嬉しいです。
実際にAndroid Studioで読み込むと、以下のようにモデルの情報が表示され、サンプルコードが表示されます。
あとはサンタと判定する確率の閾値をどうするかは要件次第ですが、こちらがアプリに組み込んでみた時のプレビューとscoreです。
まとめ
もしあなたがこっそり夜起きていて、サンタっぽい人に出くわすことがあったら、そっとこのモデルを組み込んだアプリを使って、本当にサンタかどうか判別すると良いでしょう!
メリクリ。