はじめに
以前の記事でリリース前の NxIREE を使ってみました
その後、 2024/10/04 に Hex へバージョン 0.0.1 がリリースされたため、改めてリリース版を使用します
また、 Nx のバックエンドとして動かせるようになっているので、 NxIREE.Backend で Nx を動かしてみます
今回も実行環境には Livebook を使用します
行列演算の実行
実装したノートブックはこちら
セットアップ
Hex にリリースされたため、 GitHub リポジトリーを指定せずにインストール可能です
Mix.install([
{:nx_iree, "~> 0.0"}
])
デバイスの取得
ドライバーの一覧を取得します
NxIREE.list_drivers()
|> elem(1)
実行結果
%{
"local-sync" => "Local execution using a lightweight inline synchronous queue",
"local-task" => "Local execution using the IREE multithreading task system",
"metal" => "Apple Metal"
}
metal を指定してデバイスを取得します
dev =
NxIREE.list_devices("metal")
|> elem(1)
|> hd()
実行結果
%NxIREE.Device{
ref: #Reference<0.3868306615.1943666690.245036>,
driver_name: "metal",
kind: :io,
id: 4479806464,
uri: "metal://000000010000040f",
compiler_target_backend: "metal-spirv"
}
中間表現のコンパイル
MLIR によって記述された行列演算処理を用意します
mlir_module = """
module {
func.func public @main(%arg0: tensor<4xf32>, %arg1: tensor<4xi32>) -> tensor<4xf32> {
%0 = stablehlo.cosine %arg0 : tensor<4xf32>
%1 = stablehlo.convert %arg1 : (tensor<4xi32>) -> tensor<4xf32>
%2 = stablehlo.sine %1 : tensor<4xf32>
%3 = stablehlo.add %0, %2 : tensor<4xf32>
return %3 : tensor<4xf32>
}
}
"""
Metal 用のコンパイルフラグを定義します
flags = [
"--iree-hal-target-backends=metal-spirv",
"--iree-input-type=stablehlo_xla",
"--iree-execution-model=async-internal"
]
コンパイルを実行します
module = NxIREE.compile(mlir_module, flags, output_container: Nx.template({4}, :f32))
リリース版では output_container として、出力テンプレートを指定するようになりました
行列演算の実行
コンパイルしたモジュールを NxIREE.call で実行します
arg0 = Nx.tensor([1.0, 2.0, 3.0, 4.0])
arg1 = Nx.tensor([1, -1, 1, -1])
NxIREE.call(module, [arg0, arg1], device: dev)
実行結果
{:ok,
#Nx.Tensor<
f32[4]
NxIREE.Backend(metal://000000010000040f)
[1.3817732334136963, -1.257617712020874, -0.1485215425491333, -1.4951145648956299]
>}
Elixir コードからの変換
新しいノートブックで Elixir コードから中間表現への変換を実行します
実装したノートブックはこちら
セットアップ
Mix.install([
{:exla, "~> 0.9"},
{:nx, "~> 0.9"},
{:nx_iree, "~> 0.0"},
{:kino, "~> 0.14"},
{:benchee, "~> 1.3"},
{:statistics, "~> 0.6"}
])
Softmax 関数の変換
Softmax 関数を中間表現に変換します
NxIREE の強みが出るように入力サイズを大きくしています
softmax = fn tensor ->
Nx.divide(
Nx.exp(tensor),
Nx.sum(Nx.exp(tensor), axes: [-1], keep_axes: true)
)
end
input =
{1000, 1000, 5}
|> Nx.iota(type: :f32, backend: Nx.BinaryBackend)
|> Nx.divide(1024 * 1024)
args = [input]
%{
mlir_module: mlir_module,
output_container: output_container
} = EXLA.to_mlir_module(softmax, args)
Kino.Text.new(mlir_module)
実行結果
module {
func.func public @main(%arg0: tensor<1000x1000x5xf32>) -> tensor<1000x1000x5xf32> {
%0 = stablehlo.exponential %arg0 : tensor<1000x1000x5xf32>
%1 = stablehlo.exponential %arg0 : tensor<1000x1000x5xf32>
%cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%2 = stablehlo.reduce(%1 init: %cst) across dimensions = [2] : (tensor<1000x1000x5xf32>, tensor<f32>) -> tensor<1000x1000xf32>
reducer(%arg1: tensor<f32>, %arg2: tensor<f32>) {
%6 = stablehlo.add %arg2, %arg1 : tensor<f32>
stablehlo.return %6 : tensor<f32>
}
%3 = stablehlo.reshape %2 : (tensor<1000x1000xf32>) -> tensor<1000x1000x1xf32>
%4 = stablehlo.broadcast_in_dim %3, dims = [0, 1, 2] : (tensor<1000x1000x1xf32>) -> tensor<1000x1000x5xf32>
%5 = stablehlo.divide %0, %4 : tensor<1000x1000x5xf32>
return %5 : tensor<1000x1000x5xf32>
}
}
リリース前は EXLA.to_mlir_module の出力がそのまま中間表現でしたが、リリースバージョンでは以下の形式になっています
%{
mlir_module: <中間表現>,
output_container: <出力テンプレート>,
used_inputs: <変換中に使用した入力>
}
中間表現のコンパイル
変換した関数をコンパイルします
コンパイル時に変換結果として取得した output_container を指定します
dev =
NxIREE.list_devices("metal")
|> elem(1)
|> hd()
flags = [
"--iree-hal-target-backends=metal-spirv",
"--iree-input-type=stablehlo_xla",
"--iree-execution-model=async-internal"
]
module = NxIREE.compile(mlir_module, flags, output_container: output_container)
コンパイルした関数を呼び出します
NxIREE.call(module, args, device: dev)
実行結果
{:ok,
#Nx.Tensor<
f32[4]
NxIREE.Backend(metal://000000010000040f)
[0.2540842592716217, 0.6906726360321045, 0.034386567771434784, 0.020856507122516632]
>}
速度比較
Nx.BinaryBackend と EXLA.Backend 、そしてコンパイル済中間表現で速度比較してみます
exla_input = Nx.backend_transfer(input, EXLA.Backend)
Benchee.run(%{
"nx" => fn -> softmax.(input) end,
"exla" => fn -> softmax.(exla_input) end,
"nx_iree" => fn -> NxIREE.call(module, [input], device: dev) end
})
標準出力
Name ips average deviation median 99th %
nx_iree 76.37 13.09 ms ±29.84% 12.38 ms 34.71 ms
exla 49.63 20.15 ms ±8.81% 19.93 ms 31.95 ms
nx 0.194 5144.19 ms ±0.00% 5144.19 ms 5144.19 ms
Comparison:
nx_iree 76.37
exla 49.63 - 1.54x slower +7.06 ms
nx 0.194 - 392.84x slower +5131.09 ms
1000x1000x5 の入力サイズでは NxIREE が最速になり、 Nx.BinaryBackend と比べるとおよそ 400 倍の速さになっています
NxIREE.Backend
Nx のバックエンドとして NxIREE を使用します
実装したノートブックはこちら
セットアップ
NxIREE と速度比較用の Benchee をインストールします
Mix.install([
{:nx_iree, "~> 0.0"},
{:benchee, "~> 1.3"}
])
デバイスの取得
Metal デバイスを取得します
dev =
NxIREE.list_devices("metal")
|> elem(1)
|> hd()
Softmax 関数のコンパイル
Nx.Defn.default_options で関数コンパイル時の設定を指定します
flags = [
"--iree-hal-target-backends=metal-spirv",
"--iree-input-type=stablehlo_xla",
"--iree-execution-model=async-internal"
]
Nx.Defn.default_options(
compiler: NxIREE.Compiler,
iree_compiler_flags: flags,
iree_runtime_options: [device: dev]
)
Softmax 関数を定義します
softmax = fn tensor ->
Nx.divide(
Nx.exp(tensor),
Nx.sum(Nx.exp(tensor), axes: [-1], keep_axes: true)
)
end
入力として、 1000x1000x5 の行列を NxIREE.Backend で作成します
iree_input =
{1000, 1000, 5}
|> Nx.iota(type: :f32, backend: NxIREE.Backend)
|> Nx.divide(1024 * 1024)
そのまま Softmax 関数を実行してみます
softmax.(iree_input)
実行結果
#Nx.Tensor<
f32[1000][1000][5]
NxIREE.Backend(metal://000000010000040f)
[
[
[0.1999996155500412, 0.19999980926513672, 0.20000000298023224, 0.20000018179416656, 0.20000037550926208],
[0.1999996155500412, 0.19999980926513672, 0.20000000298023224, 0.20000018179416656, 0.20000037550926208],
[0.1999996155500412, 0.19999980926513672, 0.20000000298023224, 0.20000018179416656, 0.20000037550926208],
[0.1999996155500412, 0.19999980926513672, 0.20000000298023224, 0.20000018179416656, 0.20000037550926208],
[0.1999996155500412, 0.19999980926513672, 0.20000000298023224, 0.20000018179416656, 0.20000037550926208],
[0.1999996155500412, 0.19999980926513672, 0.20000000298023224, 0.20000018179416656, 0.20000037550926208],
[0.1999996155500412, 0.19999980926513672, 0.20000000298023224, 0.20000018179416656, 0.20000037550926208],
[0.1999996155500412, 0.19999980926513672, 0.20000000298023224, 0.20000018179416656, 0.20000037550926208],
[0.1999996155500412, 0.19999980926513672, 0.20000000298023224, 0.20000018179416656, 0.20000037550926208],
[0.1999996155500412, 0.19999980926513672, 0.20000000298023224, 0.20000018179416656, 0.20000037550926208],
...
],
...
]
>
Softmax 関数をコンパイルします
第2引数には入力テンプレートを指定します
compiled_softmax = Nx.Defn.compile(softmax, [Nx.template({1000, 1000, 5}, :f32)])
コンパイル済 Softmax 関数を実行します
compiled_softmax.(iree_input)
実行結果
#Nx.Tensor<
f32[1000][1000][5]
NxIREE.Backend(metal://000000010000040f)
[
[
[0.1999996155500412, 0.19999980926513672, 0.20000000298023224, 0.20000018179416656, 0.20000037550926208],
[0.1999996155500412, 0.19999980926513672, 0.20000000298023224, 0.20000018179416656, 0.20000037550926208],
[0.1999996155500412, 0.19999980926513672, 0.20000000298023224, 0.20000018179416656, 0.20000037550926208],
[0.1999996155500412, 0.19999980926513672, 0.20000000298023224, 0.20000018179416656, 0.20000037550926208],
[0.1999996155500412, 0.19999980926513672, 0.20000000298023224, 0.20000018179416656, 0.20000037550926208],
[0.1999996155500412, 0.19999980926513672, 0.20000000298023224, 0.20000018179416656, 0.20000037550926208],
[0.1999996155500412, 0.19999980926513672, 0.20000000298023224, 0.20000018179416656, 0.20000037550926208],
[0.1999996155500412, 0.19999980926513672, 0.20000000298023224, 0.20000018179416656, 0.20000037550926208],
[0.1999996155500412, 0.19999980926513672, 0.20000000298023224, 0.20000018179416656, 0.20000037550926208],
[0.1999996155500412, 0.19999980926513672, 0.20000000298023224, 0.20000018179416656, 0.20000037550926208],
...
],
...
]
>
コンパイル前と同じ結果になりました
速度比較
Nx.BinaryBackend と EXLA.Backend の行列を用意し、速度比較します
binary_input =
{1000, 1000, 5}
|> Nx.iota(type: :f32, backend: Nx.BinaryBackend)
|> Nx.divide(1024 * 1024)
exla_input =
{1000, 1000, 5}
|> Nx.iota(type: :f32, backend: EXLA.Backend)
|> Nx.divide(1024 * 1024)
Benchee.run(%{
"nx" => fn -> softmax.(binary_input) end,
"exla" => fn -> softmax.(exla_input) end,
"nx_iree" => fn -> compiled_softmax.(iree_input) end
})
標準出力
Name ips average deviation median 99th %
exla 47.79 20.92 ms ±26.70% 20.18 ms 46.99 ms
nx_iree 34.74 28.78 ms ±15.92% 27.60 ms 52.80 ms
nx 0.192 5205.85 ms ±0.00% 5205.85 ms 5205.85 ms
Comparison:
exla 47.79
nx_iree 34.74 - 1.38x slower +7.86 ms
nx 0.192 - 248.80x slower +5184.93 ms
NxIREE が少し EXLA より遅くなりました(原因は不明)が、 Nx.BinaryBackend よりは圧倒的に速くなっています
まとめ
Nx のバックエンドとして NxIREE が使えるようになりました
Metal の威力が十分に発揮できれば、 macOS での機械学習に期待できますね