はじめに
ONNX 形式のモデルを読み込むのに以前 Axon ONNX を使用しました
しかし、 Axon ONNX の README には以下のような記述があります
See Ortex which provides full-blown compatibility for running ONNX models via ONNX Runtime bindings.
Chat-GPT による和訳
Ortexは、ONNX Runtimeバインディングを通じてONNXモデルを完全に実行できる互換性を提供します。
というわけで、 Ortex を使ってみます
実装したノートブックはこちら
画像分類
モジュールのインストール
必要なモジュールをインストールします
Mix.install(
[
{:exla, "~> 0.7"},
{:stb_image, "~> 0.6"},
{:req, "~> 0.5"},
{:kino, "~> 0.13"},
{:ortex, "~> 0.1"}
],
config: [nx: [default_backend: EXLA.Backend]]
)
StbImage は画像読込やリサイズに使用していますが、 Evision などで代用可能です
クラス一覧の取得
今回分類したいクラスの一覧を Web からダウンロードしてきます
classes =
"https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json"
|> Req.get!()
|> Map.get(:body)
モデルの読込
ONNX のモデルをダウンロードし、読み込みます
model_path = "/tmp/resnet.onnx"
unless File.exists?(model_path) do
"https://media.githubusercontent.com/media/onnx/models/main/validated/vision/classification/resnet/model/resnet18-v1-7.onnx?download=true"
|> Req.get!(connect_options: [timeout: 300_000], into: File.stream!(model_path))
end
model = Ortex.load(model_path)
実行結果(モデルの情報)では入出力形式を確認できます
#Ortex.Model<
inputs: [{"data", "Float32", [nil, 3, 224, 224]}]
outputs: [{"resnetv15_dense0_fwd", "Float32", [nil, 1000]}]>
Nx.Serving
で AI 処理を提供できるように準備します
serving = Nx.Serving.new(Ortex.Serving, model)
画像の読込
画像を Web からダウンロードして読み込みます
ONNX モデルの入力に合わせて 224 * 224 にリサイズしておきます
img_path = "/tmp/shark.jpg"
img_tensor =
"https://www.collinsdictionary.com/images/full/greatwhiteshark_157273892.jpg"
|> Req.get!()
|> then(&StbImage.read_binary!(&1.body))
|> StbImage.resize(224, 224)
|> StbImage.to_nx()
Kino.Image.new(img_tensor)
実行結果
画像のチャネル数を取得します
RGB (透過なし)なら 3 、 RGBA (透過あり)なら 4 が返ってきます
nx_channels = Nx.axis_size(img_tensor, 2)
透過ありの場合は A チャネルを切り離します(今回のサメ画像の場合はチャネル数 3 なのでそのまま)
RGB の値を正規化、標準化し、色の次元が先頭になるよう転置します
img_tensor =
case nx_channels do
3 -> img_tensor
4 -> Nx.slice(img_tensor, [0, 0, 0], [224, 224, 3])
end
|> Nx.divide(255)
|> Nx.subtract(Nx.tensor([0.485, 0.456, 0.406]))
|> Nx.divide(Nx.tensor([0.229, 0.224, 0.225]))
|> Nx.transpose()
画像データをバッチに変換します
batch = Nx.Batch.stack([img_tensor])
推論
Nx.Serving.run
で推論します
serving
|> Nx.Serving.run(batch)
|> Nx.backend_transfer()
|> elem(0)
|> Nx.flatten()
|> Nx.argsort()
|> Nx.reverse()
|> Nx.slice([0], [5])
|> Nx.to_flat_list()
|> Enum.with_index()
|> Enum.map(fn {no, index} -> {index, Map.get(classes, to_string(no))} end)
|> dbg()
Nx.Serving.run
の結果は以下のような形式になっているため、 Nx.BinaryBackend に変換してから Tuple の先頭を取得しています
Nx.argsort
が Ortex.Backend
では未実装なので Nx.BinaryBackend
に変換する必要があります
{#Nx.Tensor<
f32[1][1000]
Ortex.Backend
[
[5.107775688171387, 7.190211296081543, 17.523210525512695, ...]
]
>}
その後は確信度上位5つを取得してからクラス名に変換しています
最終的な実行結果は以下のようになります
[
{0, ["n01484850", "great_white_shark"]},
{1, ["n01491361", "tiger_shark"]},
{2, ["n01494475", "hammerhead"]},
{3, ["n01496331", "electric_ray"]},
{4, ["n02074367", "dugong"]}
]
Axon ONNX の場合との差分
インストールするモジュールを ortex に変更します
- {:axon_onnx, "~> 0.4", git: "https://github.com/mortont/axon_onnx/"}
+ {:ortex, "~> 0.1"}
モデルの読み込み時、 Nx.Serving
を用意します
- {model, params} = AxonOnnx.import(model_path)
+ model = Ortex.load(model_path)
+ serving = Nx.Serving.new(Ortex.Serving, model)
画像をバッチに変換します
img_tensor =
case nx_channels do
3 -> img_tensor
4 -> Nx.slice(img_tensor, [0, 0, 0], [224, 224, 3])
end
|> Nx.divide(255)
|> Nx.subtract(Nx.tensor([0.485, 0.456, 0.406]))
|> Nx.divide(Nx.tensor([0.229, 0.224, 0.225]))
|> Nx.transpose()
- |> Nx.new_axis(0)
+
+ batch = Nx.Batch.stack([img_tensor])
推論結果は OrtexBackend の Tuple になっています
- model
- |> Axon.predict(params, img_tensor)
+ serving
+ |> Nx.Serving.run(batch)
+ |> Nx.backend_transfer()
+ |> elem(0)
物体検出の場合
上記の Axon ONNX の場合との差分と同様に Axon ONNX の場合のコードを書き換えれば良いだけです
Mix.install(
[
{:exla, "~> 0.7"},
{:stb_image, "~> 0.6"},
{:req, "~> 0.5"},
{:kino, "~> 0.13"},
{:ortex, "~> 0.1"}
],
config: [nx: [default_backend: EXLA.Backend]]
)
classes =
"https://raw.githubusercontent.com/pjreddie/darknet/master/data/coco.names"
|> Req.get!()
|> then(&String.split(&1.body, "\n"))
|> Enum.map(&String.trim(&1))
|> Enum.filter(&(String.length(&1) > 0))
num_classes = Enum.count(classes)
model_path = "/tmp/yolov2.onnx"
unless File.exists?(model_path) do
"https://media.githubusercontent.com/media/onnx/models/main/validated/vision/object_detection_segmentation/yolov2-coco/model/yolov2-coco-9.onnx?download=true"
|> Req.get!(connect_options: [timeout: 300_000], into: File.stream!(model_path))
end
model = Ortex.load(model_path)
serving = Nx.Serving.new(Ortex.Serving, model)
anchors =
Nx.tensor([
[0.57273, 0.677385],
[1.87446, 2.06253],
[3.33843, 5.47434],
[7.88282, 3.52778],
[9.77052, 9.16828]
])
num_anchors =
anchors
|> Nx.shape()
|> elem(0)
anchors_tensor = Nx.reshape(anchors, {1, 1, 1, num_anchors, 2})
img_tensor =
"https://raw.githubusercontent.com/pjreddie/darknet/master/data/dog.jpg"
|> Req.get!()
|> then(&StbImage.read_binary!(&1.body))
|> StbImage.resize(416, 416)
|> StbImage.to_nx()
nx_channels = Nx.axis_size(img_tensor, 2)
img_tensor =
case nx_channels do
3 -> img_tensor
4 -> Nx.slice(img_tensor, [0, 0, 0], [416, 416, 3])
end
|> Nx.divide(255)
|> Nx.transpose(axes: [2, 0, 1])
batch = Nx.Batch.stack([img_tensor])
feats =
serving
|> Nx.Serving.run(batch)
|> Nx.backend_transfer()
|> elem(0)
これ以降は Axon ONNX の場合と同様です
まとめ
Axon ONNX を使ったコードから一定のルールで変換可能であるため、すぐに実装できました
今後はこちらをメインに使っていこうと思います