9
0

はじめに

ONNX 形式のモデルを読み込むのに以前 Axon ONNX を使用しました

Axon ONNX で物体検出を実行した記事はこちら

@piacerex さんによる画像分類の記事はこちら

しかし、 Axon ONNX の README には以下のような記述があります

See Ortex which provides full-blown compatibility for running ONNX models via ONNX Runtime bindings.

Chat-GPT による和訳

Ortexは、ONNX Runtimeバインディングを通じてONNXモデルを完全に実行できる互換性を提供します。

というわけで、 Ortex を使ってみます

実装したノートブックはこちら

画像分類

モジュールのインストール

必要なモジュールをインストールします

Mix.install(
  [
    {:exla, "~> 0.7"},
    {:stb_image, "~> 0.6"},
    {:req, "~> 0.5"},
    {:kino, "~> 0.13"},
    {:ortex, "~> 0.1"}
  ],
  config: [nx: [default_backend: EXLA.Backend]]
)

StbImage は画像読込やリサイズに使用していますが、 Evision などで代用可能です

クラス一覧の取得

今回分類したいクラスの一覧を Web からダウンロードしてきます

classes =
  "https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json"
  |> Req.get!()
  |> Map.get(:body)

モデルの読込

ONNX のモデルをダウンロードし、読み込みます

model_path = "/tmp/resnet.onnx"

unless File.exists?(model_path) do
  "https://media.githubusercontent.com/media/onnx/models/main/validated/vision/classification/resnet/model/resnet18-v1-7.onnx?download=true"
  |> Req.get!(connect_options: [timeout: 300_000], into: File.stream!(model_path))
end

model = Ortex.load(model_path)

実行結果(モデルの情報)では入出力形式を確認できます

#Ortex.Model<
  inputs: [{"data", "Float32", [nil, 3, 224, 224]}]
  outputs: [{"resnetv15_dense0_fwd", "Float32", [nil, 1000]}]>

Nx.Serving で AI 処理を提供できるように準備します

serving = Nx.Serving.new(Ortex.Serving, model)

画像の読込

画像を Web からダウンロードして読み込みます

ONNX モデルの入力に合わせて 224 * 224 にリサイズしておきます

img_path = "/tmp/shark.jpg"

img_tensor =
  "https://www.collinsdictionary.com/images/full/greatwhiteshark_157273892.jpg"
  |> Req.get!()
  |> then(&StbImage.read_binary!(&1.body))
  |> StbImage.resize(224, 224)
  |> StbImage.to_nx()

Kino.Image.new(img_tensor)

実行結果

shark.png

画像のチャネル数を取得します

RGB (透過なし)なら 3 、 RGBA (透過あり)なら 4 が返ってきます

nx_channels = Nx.axis_size(img_tensor, 2)

透過ありの場合は A チャネルを切り離します(今回のサメ画像の場合はチャネル数 3 なのでそのまま)

RGB の値を正規化、標準化し、色の次元が先頭になるよう転置します

img_tensor =
  case nx_channels do
    3 -> img_tensor
    4 -> Nx.slice(img_tensor, [0, 0, 0], [224, 224, 3])
  end
  |> Nx.divide(255)
  |> Nx.subtract(Nx.tensor([0.485, 0.456, 0.406]))
  |> Nx.divide(Nx.tensor([0.229, 0.224, 0.225]))
  |> Nx.transpose()

画像データをバッチに変換します

batch = Nx.Batch.stack([img_tensor])

推論

Nx.Serving.run で推論します

serving
|> Nx.Serving.run(batch)
|> Nx.backend_transfer()
|> elem(0)
|> Nx.flatten()
|> Nx.argsort()
|> Nx.reverse()
|> Nx.slice([0], [5])
|> Nx.to_flat_list()
|> Enum.with_index()
|> Enum.map(fn {no, index} -> {index, Map.get(classes, to_string(no))} end)
|> dbg()

Nx.Serving.run の結果は以下のような形式になっているため、 Nx.BinaryBackend に変換してから Tuple の先頭を取得しています

Nx.argsortOrtex.Backend では未実装なので Nx.BinaryBackend に変換する必要があります

{#Nx.Tensor<
   f32[1][1000]
   Ortex.Backend
   [
     [5.107775688171387, 7.190211296081543, 17.523210525512695, ...]
   ]
 >}

その後は確信度上位5つを取得してからクラス名に変換しています

最終的な実行結果は以下のようになります

[
  {0, ["n01484850", "great_white_shark"]},
  {1, ["n01491361", "tiger_shark"]},
  {2, ["n01494475", "hammerhead"]},
  {3, ["n01496331", "electric_ray"]},
  {4, ["n02074367", "dugong"]}
]

Axon ONNX の場合との差分

インストールするモジュールを ortex に変更します

-    {:axon_onnx, "~> 0.4", git: "https://github.com/mortont/axon_onnx/"}
+    {:ortex, "~> 0.1"}

モデルの読み込み時、 Nx.Serving を用意します

- {model, params} = AxonOnnx.import(model_path)
+ model = Ortex.load(model_path)
+ serving = Nx.Serving.new(Ortex.Serving, model)

画像をバッチに変換します

img_tensor =
  case nx_channels do
    3 -> img_tensor
    4 -> Nx.slice(img_tensor, [0, 0, 0], [224, 224, 3])
  end
  |> Nx.divide(255)
  |> Nx.subtract(Nx.tensor([0.485, 0.456, 0.406]))
  |> Nx.divide(Nx.tensor([0.229, 0.224, 0.225]))
  |> Nx.transpose()
- |> Nx.new_axis(0)
+
+ batch = Nx.Batch.stack([img_tensor])

推論結果は OrtexBackend の Tuple になっています

- model
- |> Axon.predict(params, img_tensor)
+ serving
+ |> Nx.Serving.run(batch)
+ |> Nx.backend_transfer()
+ |> elem(0)

物体検出の場合

上記の Axon ONNX の場合との差分と同様に Axon ONNX の場合のコードを書き換えれば良いだけです

Mix.install(
  [
    {:exla, "~> 0.7"},
    {:stb_image, "~> 0.6"},
    {:req, "~> 0.5"},
    {:kino, "~> 0.13"},
    {:ortex, "~> 0.1"}
  ],
  config: [nx: [default_backend: EXLA.Backend]]
)
classes =
  "https://raw.githubusercontent.com/pjreddie/darknet/master/data/coco.names"
  |> Req.get!()
  |> then(&String.split(&1.body, "\n"))
  |> Enum.map(&String.trim(&1))
  |> Enum.filter(&(String.length(&1) > 0))

num_classes = Enum.count(classes)

model_path = "/tmp/yolov2.onnx"

unless File.exists?(model_path) do
  "https://media.githubusercontent.com/media/onnx/models/main/validated/vision/object_detection_segmentation/yolov2-coco/model/yolov2-coco-9.onnx?download=true"
  |> Req.get!(connect_options: [timeout: 300_000], into: File.stream!(model_path))
end

model = Ortex.load(model_path)

serving = Nx.Serving.new(Ortex.Serving, model)

anchors =
  Nx.tensor([
    [0.57273, 0.677385],
    [1.87446, 2.06253],
    [3.33843, 5.47434],
    [7.88282, 3.52778],
    [9.77052, 9.16828]
  ])

num_anchors =
  anchors
  |> Nx.shape()
  |> elem(0)

anchors_tensor = Nx.reshape(anchors, {1, 1, 1, num_anchors, 2})
img_tensor =
  "https://raw.githubusercontent.com/pjreddie/darknet/master/data/dog.jpg"
  |> Req.get!()
  |> then(&StbImage.read_binary!(&1.body))
  |> StbImage.resize(416, 416)
  |> StbImage.to_nx()

nx_channels = Nx.axis_size(img_tensor, 2)

img_tensor =
  case nx_channels do
    3 -> img_tensor
    4 -> Nx.slice(img_tensor, [0, 0, 0], [416, 416, 3])
  end
  |> Nx.divide(255)
  |> Nx.transpose(axes: [2, 0, 1])

batch = Nx.Batch.stack([img_tensor])

feats =
  serving
  |> Nx.Serving.run(batch)
  |> Nx.backend_transfer()
  |> elem(0)

これ以降は Axon ONNX の場合と同様です

まとめ

Axon ONNX を使ったコードから一定のルールで変換可能であるため、すぐに実装できました

今後はこちらをメインに使っていこうと思います

9
0
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
9
0