LoginSignup
3
0

More than 1 year has passed since last update.

家にきた赤い服のその人は、本当にサンタですか? 〜tflite modelを作って判別しよう〜

Last updated at Posted at 2021-12-24

メリークリスマス

アドベントカレンダーのネタを探してたら、kaggleで IS THAT SANTA? というデータセットを見つけてしまったので、on-device向けのサンタ識別モデルを作成することにしました。

準備

自分でモデルを組んでも良かったんですが、ここはお手軽にTensorFlow Lite Model Makerを使うことにしました。

TensorFlow Lite Model Makerは転移学習を用いてtensorflow liteのモデルを作成できるライブラリで、少ないデータ量で精度が高いモデルを作成することができます。

また、モデル作成の実行環境は、Colaboratoryを使いました。

TensorFlow Lite Modelのチュートリアルがよくできていて、これで学習するとすぐに使えるようになります。

学習開始

Tensor Lite Model Maker をColoab上でインストールします。

!pip install -q tflite-model-maker

パッケージをインポートします。

import os
import tensorflow as tf

from tflite_model_maker import model_spec
from tflite_model_maker import image_classifier
from tflite_model_maker.config import ExportFormat
from tflite_model_maker.config import QuantizationConfig
from tflite_model_maker.image_classifier import DataLoader

import matplotlib.pyplot as plt

IS THAT SANTA?から、学習用のデータをダウンロードします。

Colab上で、Kaggleから直接DLすることができなかったので、一度KaggleからDLしたzipファイルをCloud Storageに置いてDLするように工夫しました。Colabに直接アップロードするとか、他にも方法あると思います。

image_path = tf.keras.utils.get_file(
      'is_that_santa.zip',
      'https://example.com/is_that_santa.zip',
      extract=True)
train_image_path = os.path.join(os.path.dirname(image_path), 'is that santa/train')
test_image_path = os.path.join(os.path.dirname(image_path), 'is that santa/test')

データの構造ですが、以下のようになってます。

!sudo apt-get install tree

!tree /root/.keras/datasets/is\ that\ santa/
/root/.keras/datasets/is that santa/
├── test
│   ├── not-a-santa
│   │   ├── 0.not-a-santa.jpg
│   │   ├── 100.not-a-santa.jpg
...
│   └── santa
│       ├── 0.Santa.jpg
│       ├── 100.Santa.jpg
...
└── train
    ├── not-a-santa
    │   ├── 101.not-a-santa.jpg
    │   ├── 102.not-a-santa.jpg
...
    └── santa
        ├── 101.Santa.jpg
        ├── 102.Santa.jpg

数が多いので、省略していますが、traintest でデータが分かれていて、その中で not-a-santasantaで分けられており、それぞれディレクトリ名がラベル名を表しています。

中の画像データはこんな感じです。

image.png

train, testデータそれぞれ、DataLoaderを利用して読み込みます。

train_data = DataLoader.from_folder(train_image_path)
test_data = DataLoader.from_folder(test_image_path)

モデルを学習させます。

model = image_classifier.create(train_data)

ここで問題が発生しました。

InvalidArgumentError: Unknown image file format. One of JPEG, PNG, GIF, BMP required.

画像のファイル形式を以下のコマンドで調べみました。

!ls -d1 /root/.keras/datasets/is\ that\ santa/train/santa/* | xargs -i file "{}" | grep -v "JPEG" | grep -v "PNG"
!ls -d1 /root/.keras/datasets/is\ that\ santa/train/not-a-santa/* | xargs -i file "{}" | grep -v "JPEG" | grep -v "PNG"
!ls -d1 /root/.keras/datasets/is\ that\ santa/test/santa/* | xargs -i file "{}" | grep -v "JPEG" | grep -v "PNG"
!ls -d1 /root/.keras/datasets/is\ that\ santa/test/not-a-santa/* | xargs -i file "{}" | grep -v "JPEG" | grep -v "PNG"

/root/.keras/datasets/is that santa/train/not-a-santa/463.not-a-santa.jpg: Windows desktop.ini, ASCII text, with CRLF line terminators
/root/.keras/datasets/is that santa/test/santa/401.Santa.jpg: RIFF (little-endian) data, Web/P image, VP8 encoding, 863x1300, Scaling: [none]x[none], YUV color, decoders should clamp
/root/.keras/datasets/is that santa/test/not-a-santa/395.not-a-santa.jpg: Windows desktop.ini, ASCII text, with CRLF line terminators

確認すると、なんとjpgの拡張子なのに、全然違うファイル形式のものがありました。

あわてん坊のサンタクロースなのかな?

仕方ないので手で除外し、再度読み込み直します。合計3つありました。
ちなみにjpgの拡張子でPNG形式の画像もありましたが、それは読み込めはするので、そのままにしました。

!rm /root/.keras/datasets/is\ that\ santa/train/not-a-santa/463.not-a-santa.jpg
!rm /root/.keras/datasets/is\ that\ santa/test/santa/401.Santa.jpg
!rm /root/.keras/datasets/is\ that\ santa/test/not-a-santa/395.not-a-santa.jpg
# 再度読み込み、学習させる
train_data = DataLoader.from_folder(train_image_path)
test_data = DataLoader.from_folder(test_image_path)

model = image_classifier.create(train_data)

# summary
model.summary()
Model: "sequential_5"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 hub_keras_layer_v1v2_5 (Hub  (None, 1280)             3413024   
 KerasLayerV1V2)                                                 

 dropout_5 (Dropout)         (None, 1280)              0         

 dense_5 (Dense)             (None, 2)                 2562      

=================================================================
Total params: 3,415,586
Trainable params: 2,562
Non-trainable params: 3,413,024
_________________________________________________________________

評価

テストデータで評価してみます。

loss, accuracy = model.evaluate(test_data)
20/20 [==============================] - 35s 2s/step - loss: 0.2750 - accuracy: 0.9642

loss: 0.2750
accuracy: 0.9642

なかなかの精度がでました。

以下、画像をみてみます。
not-a-santasanta の2ラベルの分類で、合計が100%になるような確率で推論されるため、高い方の確率のラベルを予測結果として、テストデータの正解と相違があるものを赤文字で表示しています。(つまり間違えてるとする)

santa.png

Androidに組み込む

Tensorflow Liteのモデルファイルをエクスポートします。

model.export(export_dir='.', tflite_filename='classify_santa_model.tflite')

Android StudioはML Model Bindingという機能があり、tfliteのモデルを簡単に入れることができます。

Tensorflow Liteはモデルの入力、出力の形式などの情報をメタデータとして付与することができます。

AndroidStudioがそのメタデータを読み取ってコードを生成してくれるため、BitmapなどでAndroidアプリ開発者が簡単に実装できるようなります。

Tensorflow Lite Model Makerでモデルを作ると、メタデータも一緒に付与されるようでした。自分でメタデータをつける場合は、ライブラリを駆使してつけないといけないため、これはかなり嬉しいです。

実際にAndroid Studioで読み込むと、以下のようにモデルの情報が表示され、サンプルコードが表示されます。

image.png

あとはサンタと判定する確率の閾値をどうするかは要件次第ですが、こちらがアプリに組み込んでみた時のプレビューとscoreです。

output.gif

まとめ

もしあなたがこっそり夜起きていて、サンタっぽい人に出くわすことがあったら、そっとこのモデルを組み込んだアプリを使って、本当にサンタかどうか判別すると良いでしょう!

メリクリ。

3
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
0