5
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

PythonAdvent Calendar 2023

Day 3

Keras の EfficientNetV2 で学習した画像分類モデルを TensorFlow.js で動作させる

Last updated at Posted at 2023-12-02

経緯

画像分類をブラウザで動かしているアプリを作っています。これまでは Google AutoML Vision を使って画像分類モデルを作っていましたが、毎回それを使うとお金がかかるし、手元に NVIDIA GeForce RTX 3080 を搭載したゲーミングPCがあるので、今後はそれを使ってモデルを作成することにしました。

画像分類をブラウザで動かしているアプリの内容はこの記事の本質ではないですが、気になる人はこちらの記事をご参照ください。

使用技術

環境

WSL2 の Ubuntu 22.04 LTS に CUDA と cuDNN をインストールした環境を構築しました。構築方法はこちらの記事をご参照ください。

画像分類モデルの作成、検証

Python 言語で TensorFlow を使いモデルを作成します。学習および検証を行う画像は JPEG 形式でローカルディスクに保存されていますが、毎回 JPEG デコードと画像縮小を行うのは処理時間がかかるため、h5py を使い HDF5 形式でモデルの入力層に合わせた配列を学習前に保存しています。そのときのプログレス表示には tqdm を使用しています。また混同行列の計算に scikit-learn を使い表示に Pandas を使用しています。

ブラウザ上で推論を動かす

ブラウザでは TensorFlow.js を使って推論を行います。TensorFlow.js 向けのモデルは SavedModel から tensorflowjs パッケージを使い変換します。

使用パッケージ

パッケージマネージャーには Poetry を使用しています。今回の使用技術をすべて列挙した pyproject.toml ファイルの [tool.poetry.dependencies] テーブルはこのようになりました。

pyproject.toml
[tool.poetry.dependencies]
python = ">=3.11,<3.12"
tensorflow = "^2.14.0"
pillow = "^10.1.0"
numpy = "^1.26.1"
h5py = "^3.10.0"
tqdm = "^4.66.1"
scikit-learn = "^1.3.2"
pandas = "^2.1.2"
tensorflowjs = "^4.12.0"

pillowtf.keras.utils.load_img 関数で使用します。

使用モデル

Keras にはいくつかのモデルが搭載されていますが、比較的最近の登場でWebアプリに載せることを想定してサイズが小さめの EfficientNetV2 を選定しました。技術的詳細については私には解説が難しいので、Qiita にある解説記事を紹介します。

Keras の EfficientNetV2 には B0, B1, B2, B3, S, M, L と複数のサイズがありますが、分類精度やアプリとして許容できるモデルサイズを考慮して B2 を選択しました。

学習および検証データを HDF5 形式で保存する

ここから先は前項で紹介した技術の使い方を詳細に解説していきます。

今回の分類先ラベルを data.py に定義します。(意味が気になる人はこちらをご参照ください。)

data.py
LABELS = ["start", "end", "kill", "death", "other"]

今回の学習画像は JPEG 形式で 52372 枚あり、ファイル名とラベル名が CSV ファイルに記録されています。

data/images.csv
frame000090.jpg,start
frame000105.jpg,other
frame000450.jpg,kill
frame000975.jpg,death
frame001712.jpg,end

さらに必要な定数を定義しました。

data.py
# 入力画像サイズ(AutoML Vision で作成したものと同じ)
INPUT_IMAGE_SIZE = 224
# カラー画像を入力にする
INPUT_IMAGE_CHANNEL = 3
# HDF5 形式のファイルの保存パス
DATASET_PATH = "dataset.hdf5"
# 学習データの入力層の配列のデータセット名
DATASET_TRAIN_XS = "train_xs"
# 学習データの出力層の配列のデータセット名
DATASET_TRAIN_YS = "train_ys"
# テストデータの入力層の配列のデータセット名
DATASET_TEST_XS = "test_xs"
# テストデータの出力層の配列のデータセット名
DATASET_TEST_YS = "test_ys"

学習およびテストに使う HDF5 形式のファイルを作成します。

make_dataset.py
import csv
import os
import h5py
import numpy as np
import tensorflow as tf
import data
from tqdm import tqdm

CSV_PATH = "data/images.csv"
IMAGE_DIR = "data/images/"


@dataclass
class LabelPath:
    label: str
    path: str


# CSV ファイルから画像のパスとラベルを取得する
all_images: list[LabelPath] = []
with open(CSV_PATH) as f:
    for row in csv.reader(f):
        filename = row[0]
        label = row[1]
        if label in data.LABELS:
            path = os.path.join(IMAGE_DIR, filename)
            all_images.append(LabelPath(label=label, path=path))
# all_images を訓練データとテストデータに分ける
# テストデータは 5000 枚、残りは訓練データ
train = all_images[:-5000]
test = all_images[-5000:]
# HDF5 ファイルを作成する
with h5py.File(data.DATASET_PATH, "w") as h:
    # あらかじめデータセットをサイズを指定して作成する
    # 訓練データの入力
    train_xs = h.create_dataset(
        data.DATASET_TRAIN_XS,
        shape=(
            len(train),
            data.INPUT_IMAGE_SIZE,
            data.INPUT_IMAGE_SIZE,
            data.INPUT_IMAGE_CHANNEL,
        ),
        dtype=np.uint8,
    )
    # 訓練データの出力
    train_ys = h.create_dataset(
        data.DATASET_TRAIN_YS,
        shape=(len(train), len(data.LABELS)),
        dtype=np.uint8,
    )
    # テストデータの入力
    test_xs = h.create_dataset(
        data.DATASET_TEST_XS,
        shape=(
            len(test),
            data.INPUT_IMAGE_SIZE,
            data.INPUT_IMAGE_SIZE,
            data.INPUT_IMAGE_CHANNEL,
        ),
        dtype=np.uint8,
    )
    # テストデータの出力
    test_ys = h.create_dataset(
        data.DATASET_TEST_YS, shape=(len(test), len(data.LABELS)), dtype=np.uint8
    )
    # データセットに1画像1ラベルずつ書き込む
    # 訓練データを書き込む
    for index, label_path in enumerate(tqdm(train)):
        # 画像を読み込む
        image = tf.keras.utils.load_img(
            label_path.path,
            target_size=(data.INPUT_IMAGE_SIZE, data.INPUT_IMAGE_SIZE),
        )
        # 入力層の配列に変換する
        x = tf.keras.utils.img_to_array(image, dtype=np.uint8)
        # 出力層の配列に変換する
        # 例: death -> [0, 0, 0, 1, 0]
        y = tf.keras.utils.to_categorical(
            data.LABELS.index(label_path.label), len(data.LABELS)
        )
        # データセットに書き込む
        train_xs[index] = x
        train_ys[index] = y
    # テストデータを書き込む
    for index, label_path in enumerate(tqdm(test)):
        image = tf.keras.utils.load_img(
            label_path.path,
            target_size=(data.INPUT_IMAGE_SIZE, data.INPUT_IMAGE_SIZE),
        )
        x = tf.keras.utils.img_to_array(image, dtype=np.uint8)
        y = tf.keras.utils.to_categorical(
            data.LABELS.index(label_path.label), len(data.LABELS)
        )
        test_xs[index] = x
        test_ys[index] = y

Data Augmentation(データ拡張)は行っていません。

通常ならば画像を反転や回転、切り取りを行いつつデータ拡張を行うところですが、今回は分類する対象がゲームのスクリーンショットで、そのような入力が想定されないため行っていません。

学習する

先ほど作成した HDF5 形式のファイルを読み込んで学習を行います。

まずは学習に必要な定数を定義します。

data.py
# SavedModel の保存先
MODEL_DIR = "data/savedmodel/"
# バッチサイズ
BATCH_SIZE = 50
# 学習画像枚数
TRAIN_SIZE = 47350
# テスト画像枚数
TEST_SIZE = 5000
# 学習のバッチの数
TRAIN_BATCH_COUNT = TRAIN_SIZE // BATCH_SIZE
# テストのバッチの数
TEST_BATCH_COUNT = TEST_SIZE // BATCH_SIZE

学習およびテストに必要なバッチを提供するジェネレータ関数を持つクラスを作ります。

data.py
import h5py

# 定数略

class Data:
    def __init__(self):
        # HDF5ファイルを開き、データセットを読み込む
        h = h5py.File(DATASET_PATH, "r")
        self.train_xs = h[DATASET_TRAIN_XS]
        self.train_ys = h[DATASET_TRAIN_YS]
        self.test_xs = h[DATASET_TEST_XS]
        self.test_ys = h[DATASET_TEST_YS]

    def generator(self):
        "訓練データのジェネレータ"
        batch_index = 0
        while True:
            xs = self.train_xs[
                batch_index * BATCH_SIZE : (batch_index + 1) * BATCH_SIZE
            ]
            ys = self.train_ys[
                batch_index * BATCH_SIZE : (batch_index + 1) * BATCH_SIZE
            ]
            yield (xs, ys)
            batch_index += 1
            # 訓練データをすべて使い切ったら、最初からやり直す
            if batch_index >= TRAIN_BATCH_COUNT:
                batch_index = 0

    def generator_validation_data(self):
        "テストデータのジェネレータ"
        batch_index = 0
        while True:
            xs = self.test_xs[
                batch_index * BATCH_SIZE : (batch_index + 1) * BATCH_SIZE
            ]
            ys = self.test_ys[
                batch_index * BATCH_SIZE : (batch_index + 1) * BATCH_SIZE
            ]
            yield (xs, ys)
            batch_index += 1
            # テストデータをすべて使い切ったら、最初からやり直す
            if batch_index >= TEST_BATCH_COUNT:
                batch_index = 0

そのクラスを使い学習を行います。

tf.keras.applications.EfficientNetV2B2 でモデルを作るときの引数について、 classes が1000以外の時は weights=None にする必要があります。

tf.keras.Model クラスの fit 関数の使い方について、訓練データのジェネレータ関数は引数 x に渡し、テストデータのジェネレータ関数は 引数 validation_data に渡します。以前は fit_generator というジェネレータ関数を渡すための関数がありましたが fit 関数と統合する形で非推奨となりました。

モデルの保存は tf.keras.saving.save_model 関数で SavedModel 形式で行います。後述しますが Keras 形式から TensorFlow.js 向けのモデルに変換すると、ブラウザで読み込んだ時にエラーになりました。

train.py
import tensorflow as tf
import data
from data import Data


g = Data()
# EfficientNetV2B2 を使う
model: tf.keras.Model = tf.keras.applications.EfficientNetV2B2(
    input_shape=(
        data.INPUT_IMAGE_SIZE,
        data.INPUT_IMAGE_SIZE,
        data.INPUT_IMAGE_CHANNEL,
    ),
    weights=None,  # type: ignore
    classes=5,
)
model.compile(
    optimizer="adam",
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)


class Callback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        "各エポック終了時にモデルを保存する"
        tf.keras.saving.save_model(model, data.MODEL_DIR)


cb = Callback()
initial_epoch = 0
model.fit(
    x=g.generator(),
    validation_data=g.generator_validation_data(),
    validation_steps=data.TEST_BATCH_COUNT,
    callbacks=[cb],
    steps_per_epoch=2 * data.TRAIN_BATCH_COUNT,
    epochs=10,
    initial_epoch=initial_epoch,
)

混同行列を表示する

テストデータに対してどのくらいの認識精度が出るか混同行列を出力して確認します。scikit-learn ライブラリの confusion_matrix 関数で混同行列を計算します。結果は PandasDataFrame で列と行にラベルをつけることで見やすくします。

predict.py
import data
from data import Data
import tensorflow as tf
import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix

g = Data()
# モデルを SavedModel から読み込む
model: tf.keras.Model = tf.keras.saving.load_model(data.MODEL_DIR)  # type: ignore

# テストデータのジェネレータを取得
generator = g.generator_validation_data()
# すべてのテストデータの正解と予測結果を格納する配列
yss_true = np.zeros((data.TEST_SIZE, len(data.LABELS)), dtype=np.float32)
yss_pred = np.zeros((data.TEST_SIZE, len(data.LABELS)), dtype=np.float32)
# バッチごとに予測する
for batch_index in range(data.TEST_BATCH_COUNT):
    xs, ys_true = next(generator)
    ys_pred = model.predict(xs)
    # 正解を格納する
    yss_true[
        batch_index * data.BATCH_SIZE : (batch_index + 1) * data.BATCH_SIZE
    ] = ys_true
    # 予測結果を格納する
    yss_pred[
        batch_index * data.BATCH_SIZE : (batch_index + 1) * data.BATCH_SIZE
    ] = ys_pred
# 混同行列を作成する
cm = confusion_matrix(yss_true.argmax(axis=1), yss_pred.argmax(axis=1))
# Pandas で混同行列にラベルをつける
df = pd.DataFrame(cm, index=data.LABELS, columns=data.LABELS)
print("Confusion Matrix:")
print(df)

出力された混合行列はこちらです。

Confusion Matrix:
       start   end  kill  death  other
start    502     0     0      0      0
end        0  2251     0      0     12
kill       0     0   138      0      3
death      0     0     4    587     10
other      1     3     7      3   1479

death ラベルがついた画像は601枚中、587枚が death として分類され、4枚が kill 、10枚が other として分類されることが分かりました。

SavedModel 形式のモデルを TensorFlow.js 向けのモデルに変換する

tensorflowjs パッケージをインストールすると使える tensorflowjs_converter コマンドを使い、SavedModel 形式のモデルを TensorFlow.js 向けのモデルに変換します。SavedModel の保存先と TensorFlow.js 向けモデルの

tensorflowjs_converter  --input_format=tf_saved_model data/savedmodel data/jsmodel

TensorFlow.js 向けのモデルが出力されました。

ls data/jsmodel
group1-shard1of9.bin  group1-shard3of9.bin  group1-shard5of9.bin  group1-shard7of9.bin  group1-shard9of9.bin
group1-shard2of9.bin  group1-shard4of9.bin  group1-shard6of9.bin  group1-shard8of9.bin  model.json

Keras モデルは EfficientNetV2 については変換できない

Keras モデルを TensorFlow.js にインポートする」という公式の説明がありますが、こちらは EfficientNetV2 については Keras 形式から変換できません。ブラウザでモデルを読み込ませると

Error: Unknown layer: Normalization.

エラーが表示されます。

TensorFlow GraphDef ベースのモデルを TensorFlow.js にインポートする」 を参考に SavedModel 形式から変換する必要があります。

Web アプリから TensorFlow.js 向けモデルを使う。

まず先ほど出力した TensorFlow.js 向けモデル(group1-shard1of9.bin ~ group1-shard9of9.bin, model.json) は HTTPS で配信できるようにします。

JavaScript の tfjs ライブラリを script タグで読み込みます。

<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@4.12.0/dist/tf.min.js"></script>

JavaScript で tf.loadGraphModel 関数を呼び出してモデルを読み込みます。

predict.js
// 分類モデル
var myModel = null;
/// モデルを読み込む
async function loadImageClassification() {
    // 今回はモデル配信元を相対で指定している。
    myModel = await tf.loadGraphModel('model/model.json');
    return 0;
}

画像分類は ImageElement に対してこのように行います。ImageElement はあらかじめ 224×224 の大きさにする必要があります。

predict.js
// 分類する
async function classify(image) {
    // ImageElement を Tensor に変換する
    const tensor = await tf.browser.fromPixelsAsync(image);
    // Tensor を入力層に合わせて Rehape する
    const x = tensor.reshape([1, 224, 224, 3]).cast('float32');
    // モデルを実行することで出力層の Tensor を得る
    const y = myModel.execute(x);
    // 出力層で最大の値のインデックスが分類結果になる。
    // JavaScript の配列に変換して使用する。
    const array = await y.argMax(axis = 1).array();
    return array[0];
}

まとめ

この記事では Keras の EfficientNetV2 で独自の画像分類モデルを作成して TensorFlow.js で動作させる方法を解説しました。TensorFlow.js が変換元として対応している TensorFlow のモデル保存形式として Keras 形式SavedModel 形式がありますが、EfficientNetV2 については SavedModel 形式しか対応していません。
また、TensorFlow.js の API を使い、画像をモデルの入力層に渡して出力層を得て、そこから分類結果を得る方法を解説しました。

5
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?