3
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

Nx を NxIREE.Backend で動かす【NxIREE 0.0.1 リリース】

Last updated at Posted at 2024-10-10

はじめに

以前の記事でリリース前の NxIREE を使ってみました

その後、 2024/10/04 に Hex へバージョン 0.0.1 がリリースされたため、改めてリリース版を使用します

また、 Nx のバックエンドとして動かせるようになっているので、 NxIREE.Backend で Nx を動かしてみます

今回も実行環境には Livebook を使用します

行列演算の実行

実装したノートブックはこちら

セットアップ

Hex にリリースされたため、 GitHub リポジトリーを指定せずにインストール可能です

Mix.install([
  {:nx_iree, "~> 0.0"}
])

デバイスの取得

ドライバーの一覧を取得します

NxIREE.list_drivers()
|> elem(1)

実行結果

%{
  "local-sync" => "Local execution using a lightweight inline synchronous queue",
  "local-task" => "Local execution using the IREE multithreading task system",
  "metal" => "Apple Metal"
}

metal を指定してデバイスを取得します

dev =
  NxIREE.list_devices("metal")
  |> elem(1)
  |> hd()

実行結果

%NxIREE.Device{
  ref: #Reference<0.3868306615.1943666690.245036>,
  driver_name: "metal",
  kind: :io,
  id: 4479806464,
  uri: "metal://000000010000040f",
  compiler_target_backend: "metal-spirv"
}

中間表現のコンパイル

MLIR によって記述された行列演算処理を用意します

mlir_module = """
module {
  func.func public @main(%arg0: tensor<4xf32>, %arg1: tensor<4xi32>) -> tensor<4xf32> {
    %0 = stablehlo.cosine %arg0 : tensor<4xf32>
    %1 = stablehlo.convert %arg1 : (tensor<4xi32>) -> tensor<4xf32>
    %2 = stablehlo.sine %1 : tensor<4xf32>
    %3 = stablehlo.add %0, %2 : tensor<4xf32>
    return %3 : tensor<4xf32>
  }
}
"""

Metal 用のコンパイルフラグを定義します

flags = [
  "--iree-hal-target-backends=metal-spirv",
  "--iree-input-type=stablehlo_xla",
  "--iree-execution-model=async-internal"
]

コンパイルを実行します

module = NxIREE.compile(mlir_module, flags, output_container: Nx.template({4}, :f32))

リリース版では output_container として、出力テンプレートを指定するようになりました

行列演算の実行

コンパイルしたモジュールを NxIREE.call で実行します

arg0 = Nx.tensor([1.0, 2.0, 3.0, 4.0])
arg1 = Nx.tensor([1, -1, 1, -1])

NxIREE.call(module, [arg0, arg1], device: dev)

実行結果

{:ok,
 #Nx.Tensor<
   f32[4]
   NxIREE.Backend(metal://000000010000040f)
   [1.3817732334136963, -1.257617712020874, -0.1485215425491333, -1.4951145648956299]
 >}

Elixir コードからの変換

新しいノートブックで Elixir コードから中間表現への変換を実行します

実装したノートブックはこちら

セットアップ

Mix.install([
  {:exla, "~> 0.9"},
  {:nx, "~> 0.9"},
  {:nx_iree, "~> 0.0"},
  {:kino, "~> 0.14"},
  {:benchee, "~> 1.3"},
  {:statistics, "~> 0.6"}
])

Softmax 関数の変換

Softmax 関数を中間表現に変換します

NxIREE の強みが出るように入力サイズを大きくしています

softmax = fn tensor ->  
  Nx.divide(
    Nx.exp(tensor),
    Nx.sum(Nx.exp(tensor), axes: [-1], keep_axes: true)
  )
end

input =
  {1000, 1000, 5}
  |> Nx.iota(type: :f32, backend: Nx.BinaryBackend)
  |> Nx.divide(1024 * 1024)

args = [input]

%{
  mlir_module: mlir_module,
  output_container: output_container
} = EXLA.to_mlir_module(softmax, args)

Kino.Text.new(mlir_module)

実行結果

module {
  func.func public @main(%arg0: tensor<1000x1000x5xf32>) -> tensor<1000x1000x5xf32> {
    %0 = stablehlo.exponential %arg0 : tensor<1000x1000x5xf32>
    %1 = stablehlo.exponential %arg0 : tensor<1000x1000x5xf32>
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %2 = stablehlo.reduce(%1 init: %cst) across dimensions = [2] : (tensor<1000x1000x5xf32>, tensor<f32>) -> tensor<1000x1000xf32>
     reducer(%arg1: tensor<f32>, %arg2: tensor<f32>)  {
      %6 = stablehlo.add %arg2, %arg1 : tensor<f32>
      stablehlo.return %6 : tensor<f32>
    }
    %3 = stablehlo.reshape %2 : (tensor<1000x1000xf32>) -> tensor<1000x1000x1xf32>
    %4 = stablehlo.broadcast_in_dim %3, dims = [0, 1, 2] : (tensor<1000x1000x1xf32>) -> tensor<1000x1000x5xf32>
    %5 = stablehlo.divide %0, %4 : tensor<1000x1000x5xf32>
    return %5 : tensor<1000x1000x5xf32>
  }
}

リリース前は EXLA.to_mlir_module の出力がそのまま中間表現でしたが、リリースバージョンでは以下の形式になっています

%{
  mlir_module: <中間表現>,
  output_container: <出力テンプレート>,
  used_inputs: <変換中に使用した入力>
}

中間表現のコンパイル

変換した関数をコンパイルします

コンパイル時に変換結果として取得した output_container を指定します

dev =
  NxIREE.list_devices("metal")
  |> elem(1)
  |> hd()

flags = [
  "--iree-hal-target-backends=metal-spirv",
  "--iree-input-type=stablehlo_xla",
  "--iree-execution-model=async-internal"
]

module = NxIREE.compile(mlir_module, flags, output_container: output_container)

コンパイルした関数を呼び出します

NxIREE.call(module, args, device: dev)

実行結果

{:ok,
 #Nx.Tensor<
   f32[4]
   NxIREE.Backend(metal://000000010000040f)
   [0.2540842592716217, 0.6906726360321045, 0.034386567771434784, 0.020856507122516632]
 >}

速度比較

Nx.BinaryBackendEXLA.Backend 、そしてコンパイル済中間表現で速度比較してみます

exla_input = Nx.backend_transfer(input, EXLA.Backend)
Benchee.run(%{
  "nx" => fn -> softmax.(input) end,
  "exla" => fn -> softmax.(exla_input) end,
  "nx_iree" => fn -> NxIREE.call(module, [input], device: dev) end
})

標準出力

Name              ips        average  deviation         median         99th %
nx_iree         76.37       13.09 ms    ±29.84%       12.38 ms       34.71 ms
exla            49.63       20.15 ms     ±8.81%       19.93 ms       31.95 ms
nx              0.194     5144.19 ms     ±0.00%     5144.19 ms     5144.19 ms

Comparison: 
nx_iree         76.37
exla            49.63 - 1.54x slower +7.06 ms
nx              0.194 - 392.84x slower +5131.09 ms

1000x1000x5 の入力サイズでは NxIREE が最速になり、 Nx.BinaryBackend と比べるとおよそ 400 倍の速さになっています

NxIREE.Backend

Nx のバックエンドとして NxIREE を使用します

実装したノートブックはこちら

セットアップ

NxIREE と速度比較用の Benchee をインストールします

Mix.install([
  {:nx_iree, "~> 0.0"},
  {:benchee, "~> 1.3"}
])

デバイスの取得

Metal デバイスを取得します

dev =
  NxIREE.list_devices("metal")
  |> elem(1)
  |> hd()

Softmax 関数のコンパイル

Nx.Defn.default_options で関数コンパイル時の設定を指定します

flags = [
  "--iree-hal-target-backends=metal-spirv",
  "--iree-input-type=stablehlo_xla",
  "--iree-execution-model=async-internal"
]

Nx.Defn.default_options(
  compiler: NxIREE.Compiler,
  iree_compiler_flags: flags,
  iree_runtime_options: [device: dev]
)

Softmax 関数を定義します

softmax = fn tensor ->  
  Nx.divide(
    Nx.exp(tensor),
    Nx.sum(Nx.exp(tensor), axes: [-1], keep_axes: true)
  )
end

入力として、 1000x1000x5 の行列を NxIREE.Backend で作成します

iree_input =
  {1000, 1000, 5}
  |> Nx.iota(type: :f32, backend: NxIREE.Backend)
  |> Nx.divide(1024 * 1024)

そのまま Softmax 関数を実行してみます

softmax.(iree_input)

実行結果

#Nx.Tensor<
  f32[1000][1000][5]
  NxIREE.Backend(metal://000000010000040f)
  [
    [
      [0.1999996155500412, 0.19999980926513672, 0.20000000298023224, 0.20000018179416656, 0.20000037550926208],
      [0.1999996155500412, 0.19999980926513672, 0.20000000298023224, 0.20000018179416656, 0.20000037550926208],
      [0.1999996155500412, 0.19999980926513672, 0.20000000298023224, 0.20000018179416656, 0.20000037550926208],
      [0.1999996155500412, 0.19999980926513672, 0.20000000298023224, 0.20000018179416656, 0.20000037550926208],
      [0.1999996155500412, 0.19999980926513672, 0.20000000298023224, 0.20000018179416656, 0.20000037550926208],
      [0.1999996155500412, 0.19999980926513672, 0.20000000298023224, 0.20000018179416656, 0.20000037550926208],
      [0.1999996155500412, 0.19999980926513672, 0.20000000298023224, 0.20000018179416656, 0.20000037550926208],
      [0.1999996155500412, 0.19999980926513672, 0.20000000298023224, 0.20000018179416656, 0.20000037550926208],
      [0.1999996155500412, 0.19999980926513672, 0.20000000298023224, 0.20000018179416656, 0.20000037550926208],
      [0.1999996155500412, 0.19999980926513672, 0.20000000298023224, 0.20000018179416656, 0.20000037550926208],
      ...
    ],
    ...
  ]
>

Softmax 関数をコンパイルします

第2引数には入力テンプレートを指定します

compiled_softmax = Nx.Defn.compile(softmax, [Nx.template({1000, 1000, 5}, :f32)])

コンパイル済 Softmax 関数を実行します

compiled_softmax.(iree_input)

実行結果

#Nx.Tensor<
  f32[1000][1000][5]
  NxIREE.Backend(metal://000000010000040f)
  [
    [
      [0.1999996155500412, 0.19999980926513672, 0.20000000298023224, 0.20000018179416656, 0.20000037550926208],
      [0.1999996155500412, 0.19999980926513672, 0.20000000298023224, 0.20000018179416656, 0.20000037550926208],
      [0.1999996155500412, 0.19999980926513672, 0.20000000298023224, 0.20000018179416656, 0.20000037550926208],
      [0.1999996155500412, 0.19999980926513672, 0.20000000298023224, 0.20000018179416656, 0.20000037550926208],
      [0.1999996155500412, 0.19999980926513672, 0.20000000298023224, 0.20000018179416656, 0.20000037550926208],
      [0.1999996155500412, 0.19999980926513672, 0.20000000298023224, 0.20000018179416656, 0.20000037550926208],
      [0.1999996155500412, 0.19999980926513672, 0.20000000298023224, 0.20000018179416656, 0.20000037550926208],
      [0.1999996155500412, 0.19999980926513672, 0.20000000298023224, 0.20000018179416656, 0.20000037550926208],
      [0.1999996155500412, 0.19999980926513672, 0.20000000298023224, 0.20000018179416656, 0.20000037550926208],
      [0.1999996155500412, 0.19999980926513672, 0.20000000298023224, 0.20000018179416656, 0.20000037550926208],
      ...
    ],
    ...
  ]
>

コンパイル前と同じ結果になりました

速度比較

Nx.BinaryBackendEXLA.Backend の行列を用意し、速度比較します

binary_input =
  {1000, 1000, 5}
  |> Nx.iota(type: :f32, backend: Nx.BinaryBackend)
  |> Nx.divide(1024 * 1024)

exla_input =
  {1000, 1000, 5}
  |> Nx.iota(type: :f32, backend: EXLA.Backend)
  |> Nx.divide(1024 * 1024)
Benchee.run(%{
  "nx" => fn -> softmax.(binary_input) end,
  "exla" => fn -> softmax.(exla_input) end,
  "nx_iree" => fn -> compiled_softmax.(iree_input) end
})

標準出力

Name              ips        average  deviation         median         99th %
exla            47.79       20.92 ms    ±26.70%       20.18 ms       46.99 ms
nx_iree         34.74       28.78 ms    ±15.92%       27.60 ms       52.80 ms
nx              0.192     5205.85 ms     ±0.00%     5205.85 ms     5205.85 ms

Comparison: 
exla            47.79
nx_iree         34.74 - 1.38x slower +7.86 ms
nx              0.192 - 248.80x slower +5184.93 ms

NxIREE が少し EXLA より遅くなりました(原因は不明)が、 Nx.BinaryBackend よりは圧倒的に速くなっています

まとめ

Nx のバックエンドとして NxIREE が使えるようになりました

Metal の威力が十分に発揮できれば、 macOS での機械学習に期待できますね

3
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?