15
2

More than 1 year has passed since last update.

Livebook で Bumblebee の画像分類を実行する

Last updated at Posted at 2022-12-13

はじめに

前回の記事で Stable Diffusion を実行しました

せっかくなので、公式のノートブックに載っている他のモデルも動かしてみましょう

今回は ResNet50 などの画像分類モデルを使って、画像に写っているものを識別します

Bumblebee が前処理や後処理をやってくれているので、 Python よりも簡単に AI を使うことができます

このシリーズの記事

実装の全文はこちら

実行環境

  • MacBook Pro 13 inchi
    • 2.4 GHz クアッドコアIntel Core i5
    • 16 GB 2133 MHz LPDDR3
  • macOS Ventura 13.0.1
  • Rancher Desktop 1.6.2
    • メモリ割り当て 12 GB
    • CPU 割り当て 6 コア

Livebook 0.8.0 の Docker イメージを元にしたコンテナで動かしました

コンテナ定義はこちらを参照

セットアップ

必要なモジュールをインストールして EXLA.Backend で Nx が動くようにします

Mix.install(
  [
    {:bumblebee, "~> 0.1"},
    {:nx, "~> 0.4"},
    {:exla, "~> 0.4"},
    {:kino, "~> 0.8"}
  ],
  config: [nx: [default_backend: EXLA.Backend]]
)

EXLA.Backend で動いていることを確認します

Nx.default_backend()

コンテナで動かしている場合、キャッシュディレクトリーを指定した方が都合がいいです

※詳細は前回の記事を見てください

cache_dir = "/tmp/bumblebee_cache"

モデルのダウンロード

Stable Diffusion のときと同様、モデルファイルを Haggin Face からダウンロードしてきて読み込みます

必要な場合は cache_dir を指定します

{:ok, resnet} =
  Bumblebee.load_model({
    :hf,
    "microsoft/resnet-50",
    cache_dir: cache_dir
  })
{:ok, featurizer} =
  Bumblebee.load_featurizer({
    :hf,
    "microsoft/resnet-50",
    cache_dir: cache_dir
  })

featurizer はモデルに画像を読み込ませるための前処理定義です

ここで指定している ResNet 50 の場合、 https://huggingface.co/microsoft/resnet-50/resolve/main/preprocessor_config.json を読むことになります

preprocessor_config.json は以下のような内容です

これに従って、例えば画像を 224 * 224 にリサイズしたりします

{
  "crop_pct": 0.875,
  "do_normalize": true,
  "do_resize": true,
  "feature_extractor_type": "ConvNextFeatureExtractor",
  "image_mean": [
    0.485,
    0.456,
    0.406
  ],
  "image_std": [
    0.229,
    0.224,
    0.225
  ],
  "resample": 3,
  "size": 224
}

画像分類の実行

画像の準備

Kino.Input.image で分類する画像を用意します

セルを実行して表示される領域に画像ファイルをドラッグ&ドロップしてください

image_input = Kino.Input.image("IMAGE", size: {224, 224})

スクリーンショット 2022-12-13 14.23.03.png

入力した画像を Nx のテンソルとして読み込みます

image =
  image_input
  |> Kino.Input.read()
  |> then(fn input ->
    input.data
    |> Nx.from_binary(:u8)
    |> Nx.reshape({input.height, input.width, 3})
  end)

Kino.Image.new(image)

スクリーンショット 2022-12-13 14.24.15.png

手動推論

公式とは順序が逆になりますが、先に手動推論をやってみましょう

まず画像に前処理を掛けます

inputs = Bumblebee.apply_featurizer(featurizer, image)

1 * 224 * 224 * 3 のテンソルになりました

スクリーンショット 2022-12-13 14.35.03.png

このテンソルを Axon で推論します

outputs = Axon.predict(resnet.model, resnet.params, inputs)

スクリーンショット 2022-12-13 14.36.13.png

推論結果が 1 * 1000 のテンソルで得られました

今回使っているモデルは画像を 1000 クラスに分類するので、各クラス毎の確信度が得られています

このままの数値は使いづらいので softmax して、 その中でも TOP5 だけを取り出し、スコアとクラス名(ラベル)に変換します

最後に結果が見やすいようにデータテーブルで表示します

outputs.logits
|> Nx.squeeze()
|> Axon.Activations.softmax()
|> Bumblebee.Utils.Nx.top_k(k: 5)
|> then(fn {scores, class_ids} ->
  scores
  |> Nx.to_flat_list()
  |> Enum.zip(Nx.to_flat_list(class_ids))
  |> Enum.map(fn {score, class_id} ->
    [
      label: resnet.spec.id_to_label[class_id],
      score: score
    ]
  end)
end)
|> Kino.DataTable.new()
|> dbg()

スクリーンショット 2022-12-13 14.40.22.png

ここまでの実装を見て分かる通り、前処理と後処理だけでも結構めんどくさいです

というわけで、 Bumblebee は勝手に前処理と後処理をやってくれます

Nx.Serving による提供

こっちが本来の Bumblebee の使い方です

serving = Bumblebee.Vision.image_classification(resnet, featurizer)
serving
|> Nx.Serving.run(image)
|> then(&Kino.DataTable.new(&1.predictions))

スクリーンショット 2022-12-13 14.43.34.png

何も考えず Nx.Serving.run に渡せば predictions の中に結果が返ってきます

すごく便利ですね

他のモデル

他のモデルもリポジトリーIDを書き換えるだけで実行できる、とのことなので、関数化して実行してみます

serve_model = fn repository_id ->
  {:ok, model} =
    Bumblebee.load_model({
      :hf,
      repository_id,
      cache_dir: cache_dir
    })

  {:ok, featurizer} =
    Bumblebee.load_featurizer({
      :hf,
      repository_id,
      cache_dir: cache_dir
    })

  Bumblebee.Vision.image_classification(model, featurizer)
end
"facebook/convnext-tiny-224"
|> serve_model.()
|> Nx.Serving.run(image)
|> then(&Kino.DataTable.new(&1.predictions))

スクリーンショット 2022-12-13 14.46.09.png

"google/vit-base-patch16-224"
|> serve_model.()
|> Nx.Serving.run(image)
|> then(&Kino.DataTable.new(&1.predictions))

スクリーンショット 2022-12-13 14.46.41.png

"facebook/deit-base-distilled-patch16-224"
|> serve_model.()
|> Nx.Serving.run(image)
|> then(&Kino.DataTable.new(&1.predictions))

スクリーンショット 2022-12-13 14.47.17.png

モデルによって推論結果が異なります

まとめ

画像分類モデルが簡単に実行できますね!

YOLO は後処理が面倒すぎるので、まだ Bumblebee には入らなさそう

15
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
15
2