はじめに
@zacky1972 さんが、 NxIREE なるものについて調べてほしいと零していたので、動かしてみました
実装したノートブックはこちら
IREE とは
IREE = Intermediate Representation Execution Environment = 中間表現実行環境
イーリーと読みます
機械学習モデルを CPU, GPU, TPU など、様々なハードウェア上で実行できる環境を提供します
XLA が TensorFlow 専用なのに対し、 IREE は PyTorch など他のフレームワークにも対応しています
また、 IREE は CPU, GPT, TPU 以外のモバイルデバイスにも対応しています
NxIREE の目的
NxIREE の README から引用
Companion library to EXLA, providing bindings for the IREE runtime for MLIR.
和訳
EXLAの補助ライブラリであり、MLIRのIREEランタイムのバインディングを提供します
MLIR = Multi-Level Intermediate Representation = 多段階中間表現
MLIR は以下のように利用されます
- Python などの高レベルコード(人間が書いたコード)を中間表現に変換する
- 効率的に実行できるよう最適化する
- 低レベルコード(機械が実行するコード)に変換する
EXLA を様々な環境で実行できるようにするためのモジュールということのようです
実行環境
Apple Silicon の Mac にしか対応していなさそうなので、 M2Mac で実行しました
- MacBook Pro 13 インチ M2 2022
- macOS Sonoma 14.5
- Erlang 27.0.1
- Elixir 1.17.2-otp-27
Erlang と Elixir は asdf でインストール
iex での実行
リポジトリーにある "run.exs" を動かしてみます
NxIREE.list_drivers() |> IO.inspect(label: "drivers")
{:ok, [dev | _]} = NxIREE.list_devices("metal") |> IO.inspect()
# Obtained by using EXLA.to_mlir_module(fn a, b -> Nx.add(Nx.cos(a), Nx.sin(b)) end, [Nx.template({4}, :f32), Nx.template({4}, :s64)])
mlir_module = """
module {
func.func public @main(%arg0: tensor<4xf32>, %arg1: tensor<4xi64>) -> tensor<4xf32> {
%0 = stablehlo.cosine %arg0 : tensor<4xf32>
%1 = stablehlo.convert %arg1 : (tensor<4xi64>) -> tensor<4xf32>
%2 = stablehlo.sine %1 : tensor<4xf32>
%3 = stablehlo.add %0, %2 : tensor<4xf32>
return %3 : tensor<4xf32>
}
}
"""
# flags = ["--iree-hal-target-backends=cuda", "--iree-input-type=stablehlo_xla", "--iree-execution-model=async-internal"]
flags = ["--iree-hal-target-backends=metal-spirv", "--iree-input-type=stablehlo_xla", "--iree-execution-model=async-internal"]
%NxIREE.Module{} = module = NxIREE.compile(mlir_module, flags)
arg0 = Nx.tensor([1.0, 2.0, 3.0, 4.0])
arg1 = Nx.tensor([1, -1, 1, -1])
IO.gets("Press enter to continue - #{System.pid()}")
{:ok, [result]} = NxIREE.call(module, [arg0, arg1], device: dev) |> IO.inspect()
IO.inspect(result, limit: 4)
IO.gets("Press enter to finish")
リポジトリーのクローン
git clone https://github.com/elixir-nx/nx_iree
cd nx_iree
依存モジュールのインストール
mix deps.get
コンパイル
コンパイルには Ninja が必要です
未インストールの場合はインストールします
brew install ninja
コンパイルを実行します
mix compile
リネーム
"run.exs" を ".iex.exs" にリネームします
".iex.exs" は iex -S mix
実行時に自動実行されます
処理の実行
$ iex -S mix
実行結果
drivers: {:ok,
%{
"local-sync" => "Local execution using a lightweight inline synchronous queue",
"local-task" => "Local execution using the IREE multithreading task system",
"metal" => "Apple Metal"
}}
{:ok, ["metal://default", "metal://00000001000003e9"]}
Press enter to continue - 52492
{:ok,
[
#Nx.Tensor<
f32[4]
NxIREE.Tensor(metal://default)
[1.3817732334136963, -1.257617712020874, -0.1485215425491333, -1.4951145648956299]
>
]}
#Nx.Tensor<
f32[4]
NxIREE.Tensor(metal://default)
[1.3817732334136963, -1.257617712020874, -0.1485215425491333, -1.4951145648956299]
>
Press enter to finish
とりあえず実行できました
各コードの意味は Livebook で動かしながら確認してみましょう
Livebook での実行
"run.exs" の内容を Livebook で個別に実行し、処理の内容を理解していきます
モジュールのインストール
まだ Hex に公開されていないので、 GitHub から NxIREE をインストールします
Mix.install([
{:nx_iree, "~> 0.1", git: "https://github.com/elixir-nx/nx_iree"}
])
デバイスの取得
デバイスの一覧を取得します
NxIREE.list_drivers()
|> elem(1)
実行結果
%{
"local-sync" => "Local execution using a lightweight inline synchronous queue",
"local-task" => "Local execution using the IREE multithreading task system",
"metal" => "Apple Metal"
}
M2Mac なので Metal が使えるようです
Metal は iPhone や MacBook など、 Apple 製品で GPU を操作するための API です
IREE によって Metal を動かすための低レベルコードを生成できるようです
実際に使用するデバイスを取得します
dev =
NxIREE.list_devices("metal")
|> elem(1)
|> hd()
実行結果
"metal://default"
中間表現のコンパイル
MLIR によって記述された行列演算処理を用意します
コメントにある通り、このコードは EXLA.to_mlir_module
で Elixir の高レベルコードを中間表現として取得したものです
# Obtained by using EXLA.to_mlir_module(fn a, b -> Nx.add(Nx.cos(a), Nx.sin(b)) end, [Nx.template({4}, :f32), Nx.template({4}, :s64)])
mlir_module = """
module {
func.func public @main(%arg0: tensor<4xf32>, %arg1: tensor<4xi64>) -> tensor<4xf32> {
%0 = stablehlo.cosine %arg0 : tensor<4xf32>
%1 = stablehlo.convert %arg1 : (tensor<4xi64>) -> tensor<4xf32>
%2 = stablehlo.sine %1 : tensor<4xf32>
%3 = stablehlo.add %0, %2 : tensor<4xf32>
return %3 : tensor<4xf32>
}
}
"""
中間表現の中に出てくる stablehlo.cosine
などの関数についてはこちらで確認できます
StableHLO という機械学習用の高レベル操作を集めたもののようです
中間表現を Metal 用にコンパイルします
コンパイル時のフラグ設定
flags = [
"--iree-hal-target-backends=metal-spirv",
"--iree-input-type=stablehlo_xla",
"--iree-execution-model=async-internal"
]
コンパイルの実行
module = NxIREE.compile(mlir_module, flags)
演算処理の実行
コンパイルした処理を実行します
arg0 = Nx.tensor([1.0, 2.0, 3.0, 4.0])
arg1 = Nx.tensor([1, -1, 1, -1])
NxIREE.call(module, [arg0, arg1], device: dev)
実行結果
{:ok,
[
#Nx.Tensor<
f32[4]
NxIREE.Tensor(metal://default)
[1.3817732334136963, -1.257617712020874, -0.1485215425491333, -1.4951145648956299]
>
]}
Nx での演算結果と比較
Nx で演算した場合と同じになるか確認します
nx_function = fn a, b ->
Nx.add(Nx.cos(a), Nx.sin(b))
end
nx_function.(arg0, arg1)
#Nx.Tensor<
f32[4]
[1.3817732334136963, -1.2576178312301636, -0.1485215425491333, -1.4951145648956299]
>
微妙に誤差はありますが、ほぼ同じ結果になっています
Nx の演算を Metal 上で実行できたようです
Elixir コードからの変換
EXLA の最新版には EXLA.to_mlir_module
が含まれているため、これを利用して Elixir コードから中間表現が生成できます
モジュールのインストール
EXLA と Nx も GitHub から取得し、速度比較のために Benchee 、 ランダム数値(正規分布)生成のために Statistics をインストールします
Mix.install([
{:exla, "~> 0.7", github: "elixir-nx/nx", sparse: "exla"},
{:nx, "~> 0.7", github: "elixir-nx/nx", sparse: "nx", override: true},
{:nx_iree, "~> 0.1", git: "https://github.com/elixir-nx/nx_iree"},
{:kino, "~> 0.13"},
{:benchee, "~> 1.3"},
{:statistics, "~> 0.6"}
])
中間表現の取得
Softmax 関数の中間表現を生成してみます
softmax = fn tensor ->
Nx.divide(
Nx.exp(tensor),
Nx.sum(Nx.exp(tensor), axes: [-1], keep_axes: true)
)
end
input = Nx.tensor([1.0, 2.0, -1.0, -1.5])
args = [input]
mlir_module = EXLA.to_mlir_module(softmax, args)
Kino.Text.new(mlir_module)
実行結果
module {
func.func public @main(%arg0: tensor<4xf32>) -> tensor<4xf32> {
%0 = stablehlo.exponential %arg0 : tensor<4xf32>
%1 = stablehlo.exponential %arg0 : tensor<4xf32>
%cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%2 = stablehlo.reduce(%1 init: %cst) across dimensions = [0] : (tensor<4xf32>, tensor<f32>) -> tensor<f32>
reducer(%arg1: tensor<f32>, %arg2: tensor<f32>) {
%6 = stablehlo.add %arg2, %arg1 : tensor<f32>
stablehlo.return %6 : tensor<f32>
}
%3 = stablehlo.reshape %2 : (tensor<f32>) -> tensor<1xf32>
%4 = stablehlo.broadcast_in_dim %3, dims = [0] : (tensor<1xf32>) -> tensor<4xf32>
%5 = stablehlo.divide %0, %4 : tensor<4xf32>
return %5 : tensor<4xf32>
}
}
何となく、ちゃんとできていそうな気がします
コンパイル
デバイスの取得
dev =
NxIREE.list_devices("metal")
|> elem(1)
|> hd()
フラグ設定
flags = [
"--iree-hal-target-backends=metal-spirv",
"--iree-input-type=stablehlo_xla",
"--iree-execution-model=async-internal"
]
module = NxIREE.compile(mlir_module, flags)
演算処理の実行
NxIREE.call(module, args, device: dev)
実行結果
{:ok,
[
#Nx.Tensor<
f32[4]
NxIREE.Tensor(metal://default)
[0.2540842592716217, 0.6906726360321045, 0.034386567771434784, 0.020856507122516632]
>
]}
直接 Softmax を実行してみます
softmax.(input)
実行結果
#Nx.Tensor<
f32[4]
[0.2540842592716217, 0.6906726956367493, 0.034386567771434784, 0.02085650898516178]
>
誤差はあるものの、ほぼ同じ結果です
速度比較
Nx.BinaryBackend の場合と EXLA.Backend の場合、 NxIREE の場合で速度比較してみます
exla_input = Nx.backend_transfer(input, EXLA.Backend)
Benchee.run(%{
"nx" => fn -> softmax.(input) end,
"exla" => fn -> softmax.(exla_input) end,
"nx_iree" => fn -> NxIREE.call(module, [input], device: dev) end
})
実行結果
Name ips average deviation median 99th %
nx 265.49 K 3.77 μs ±257.11% 3.46 μs 6.13 μs
exla 14.52 K 68.88 μs ±43.40% 66.37 μs 104.92 μs
nx_iree 1.79 K 557.50 μs ±14.71% 550.14 μs 860.11 μs
Comparison:
nx 265.49 K
exla 14.52 K - 18.29x slower +65.11 μs
nx_iree 1.79 K - 148.01x slower +553.74 μs
予測に反して Nx.BinaryBackend が速く、 NxIREE が遅くなってしまいました
恐らく行列が小さすぎることが問題なので、行列を大きくしてみます
長さ 100,000 のテンソルを用意しました
input =
1..100000
|> Enum.map(fn _ -> Statistics.Distributions.Normal.rand() end)
|> Nx.tensor()
args = [input]
mlir_module = EXLA.to_mlir_module(softmax, args)
Kino.Text.new(mlir_module)
実行結果
module {
func.func public @main(%arg0: tensor<100000xf32>) -> tensor<100000xf32> {
%0 = stablehlo.exponential %arg0 : tensor<100000xf32>
%1 = stablehlo.exponential %arg0 : tensor<100000xf32>
%cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%2 = stablehlo.reduce(%1 init: %cst) across dimensions = [0] : (tensor<100000xf32>, tensor<f32>) -> tensor<f32>
reducer(%arg1: tensor<f32>, %arg2: tensor<f32>) {
%6 = stablehlo.add %arg2, %arg1 : tensor<f32>
stablehlo.return %6 : tensor<f32>
}
%3 = stablehlo.reshape %2 : (tensor<f32>) -> tensor<1xf32>
%4 = stablehlo.broadcast_in_dim %3, dims = [0] : (tensor<1xf32>) -> tensor<100000xf32>
%5 = stablehlo.divide %0, %4 : tensor<100000xf32>
return %5 : tensor<100000xf32>
}
}
コンパイルします
module = NxIREE.compile(mlir_module, flags)
改めて速度比較してみます
Benchee.run(%{
"nx" => fn -> softmax.(input) end,
"exla" => fn -> softmax.(exla_input) end,
"nx_iree" => fn -> NxIREE.call(module, [input], device: dev) end
})
実行結果
Name ips average deviation median 99th %
exla 14.29 K 70.00 μs ±26.28% 67.71 μs 101.50 μs
nx_iree 1.13 K 886.81 μs ±22.57% 854.91 μs 1675.55 μs
nx 0.0232 K 43122.48 μs ±3.44% 42999.11 μs 52094.74 μs
Comparison:
exla 14.29 K
nx_iree 1.13 K - 12.67x slower +816.81 μs
nx 0.0232 K - 616.08x slower +43052.48 μs
EXLA が最も速いという結果になりましたが、 BiraryBackend と比べれば IREE もかなり高速です
まとめ
NxIREE を使うことで、様々な環境で高速に演算処理が実行できそうです
まだ未リリースなので、これからの発展に期待ですね