21
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

はじめに

@zacky1972 さんが、 NxIREE なるものについて調べてほしいと零していたので、動かしてみました

実装したノートブックはこちら

IREE とは

IREE = Intermediate Representation Execution Environment = 中間表現実行環境

 イーリーと読みます

機械学習モデルを CPU, GPU, TPU など、様々なハードウェア上で実行できる環境を提供します

XLA が TensorFlow 専用なのに対し、 IREE は PyTorch など他のフレームワークにも対応しています

また、 IREE は CPU, GPT, TPU 以外のモバイルデバイスにも対応しています

NxIREE の目的

NxIREE の README から引用

Companion library to EXLA, providing bindings for the IREE runtime for MLIR.

和訳

EXLAの補助ライブラリであり、MLIRのIREEランタイムのバインディングを提供します

MLIR = Multi-Level Intermediate Representation = 多段階中間表現

MLIR は以下のように利用されます

  • Python などの高レベルコード(人間が書いたコード)を中間表現に変換する
  • 効率的に実行できるよう最適化する
  • 低レベルコード(機械が実行するコード)に変換する

EXLA を様々な環境で実行できるようにするためのモジュールということのようです

実行環境

Apple Silicon の Mac にしか対応していなさそうなので、 M2Mac で実行しました

  • MacBook Pro 13 インチ M2 2022
  • macOS Sonoma 14.5
  • Erlang 27.0.1
  • Elixir 1.17.2-otp-27

Erlang と Elixir は asdf でインストール

iex での実行

リポジトリーにある "run.exs" を動かしてみます

NxIREE.list_drivers() |> IO.inspect(label: "drivers")

{:ok, [dev | _]} = NxIREE.list_devices("metal") |> IO.inspect()

# Obtained by using EXLA.to_mlir_module(fn a, b -> Nx.add(Nx.cos(a), Nx.sin(b)) end, [Nx.template({4}, :f32), Nx.template({4}, :s64)])
mlir_module = """
module {
  func.func public @main(%arg0: tensor<4xf32>, %arg1: tensor<4xi64>) -> tensor<4xf32> {
    %0 = stablehlo.cosine %arg0 : tensor<4xf32>
    %1 = stablehlo.convert %arg1 : (tensor<4xi64>) -> tensor<4xf32>
    %2 = stablehlo.sine %1 : tensor<4xf32>
    %3 = stablehlo.add %0, %2 : tensor<4xf32>
    return %3 : tensor<4xf32>
  }
}
"""

# flags = ["--iree-hal-target-backends=cuda", "--iree-input-type=stablehlo_xla", "--iree-execution-model=async-internal"]
flags = ["--iree-hal-target-backends=metal-spirv", "--iree-input-type=stablehlo_xla", "--iree-execution-model=async-internal"]

%NxIREE.Module{} = module = NxIREE.compile(mlir_module, flags)

arg0 = Nx.tensor([1.0, 2.0, 3.0, 4.0])
arg1 = Nx.tensor([1, -1, 1, -1])

IO.gets("Press enter to continue - #{System.pid()}")
{:ok, [result]} = NxIREE.call(module, [arg0, arg1], device: dev) |> IO.inspect()

IO.inspect(result, limit: 4)


IO.gets("Press enter to finish")

リポジトリーのクローン

git clone https://github.com/elixir-nx/nx_iree
cd nx_iree

依存モジュールのインストール

mix deps.get

コンパイル

コンパイルには Ninja が必要です

未インストールの場合はインストールします

brew install ninja

コンパイルを実行します

mix compile

リネーム

"run.exs" を ".iex.exs" にリネームします

".iex.exs" は iex -S mix 実行時に自動実行されます

処理の実行

$ iex -S mix

実行結果

drivers: {:ok,
 %{
   "local-sync" => "Local execution using a lightweight inline synchronous queue",
   "local-task" => "Local execution using the IREE multithreading task system",
   "metal" => "Apple Metal"
 }}
{:ok, ["metal://default", "metal://00000001000003e9"]}
Press enter to continue - 52492
{:ok,
 [
   #Nx.Tensor<
     f32[4]
     NxIREE.Tensor(metal://default)
     [1.3817732334136963, -1.257617712020874, -0.1485215425491333, -1.4951145648956299]
   >
 ]}
#Nx.Tensor<
  f32[4]
  NxIREE.Tensor(metal://default)
  [1.3817732334136963, -1.257617712020874, -0.1485215425491333, -1.4951145648956299]
>
Press enter to finish

とりあえず実行できました

各コードの意味は Livebook で動かしながら確認してみましょう

Livebook での実行

"run.exs" の内容を Livebook で個別に実行し、処理の内容を理解していきます

モジュールのインストール

まだ Hex に公開されていないので、 GitHub から NxIREE をインストールします

Mix.install([
  {:nx_iree, "~> 0.1", git: "https://github.com/elixir-nx/nx_iree"}
])

デバイスの取得

デバイスの一覧を取得します

NxIREE.list_drivers()
|> elem(1)

実行結果

%{
  "local-sync" => "Local execution using a lightweight inline synchronous queue",
  "local-task" => "Local execution using the IREE multithreading task system",
  "metal" => "Apple Metal"
}

M2Mac なので Metal が使えるようです

Metal は iPhone や MacBook など、 Apple 製品で GPU を操作するための API です

IREE によって Metal を動かすための低レベルコードを生成できるようです

実際に使用するデバイスを取得します

dev =
  NxIREE.list_devices("metal")
  |> elem(1)
  |> hd()

実行結果

"metal://default"

中間表現のコンパイル

MLIR によって記述された行列演算処理を用意します

コメントにある通り、このコードは EXLA.to_mlir_module で Elixir の高レベルコードを中間表現として取得したものです

# Obtained by using EXLA.to_mlir_module(fn a, b -> Nx.add(Nx.cos(a), Nx.sin(b)) end, [Nx.template({4}, :f32), Nx.template({4}, :s64)])
mlir_module = """
module {
  func.func public @main(%arg0: tensor<4xf32>, %arg1: tensor<4xi64>) -> tensor<4xf32> {
    %0 = stablehlo.cosine %arg0 : tensor<4xf32>
    %1 = stablehlo.convert %arg1 : (tensor<4xi64>) -> tensor<4xf32>
    %2 = stablehlo.sine %1 : tensor<4xf32>
    %3 = stablehlo.add %0, %2 : tensor<4xf32>
    return %3 : tensor<4xf32>
  }
}
"""

中間表現の中に出てくる stablehlo.cosine などの関数についてはこちらで確認できます

StableHLO という機械学習用の高レベル操作を集めたもののようです

中間表現を Metal 用にコンパイルします

コンパイル時のフラグ設定

flags = [
  "--iree-hal-target-backends=metal-spirv",
  "--iree-input-type=stablehlo_xla",
  "--iree-execution-model=async-internal"
]

コンパイルの実行

module = NxIREE.compile(mlir_module, flags)

演算処理の実行

コンパイルした処理を実行します

arg0 = Nx.tensor([1.0, 2.0, 3.0, 4.0])
arg1 = Nx.tensor([1, -1, 1, -1])

NxIREE.call(module, [arg0, arg1], device: dev)

実行結果

{:ok,
 [
   #Nx.Tensor<
     f32[4]
     NxIREE.Tensor(metal://default)
     [1.3817732334136963, -1.257617712020874, -0.1485215425491333, -1.4951145648956299]
   >
 ]}

Nx での演算結果と比較

Nx で演算した場合と同じになるか確認します

nx_function = fn a, b ->
  Nx.add(Nx.cos(a), Nx.sin(b))
end

nx_function.(arg0, arg1)
#Nx.Tensor<
  f32[4]
  [1.3817732334136963, -1.2576178312301636, -0.1485215425491333, -1.4951145648956299]
>

微妙に誤差はありますが、ほぼ同じ結果になっています

Nx の演算を Metal 上で実行できたようです

Elixir コードからの変換

EXLA の最新版には EXLA.to_mlir_module が含まれているため、これを利用して Elixir コードから中間表現が生成できます

モジュールのインストール

EXLA と Nx も GitHub から取得し、速度比較のために Benchee 、 ランダム数値(正規分布)生成のために Statistics をインストールします

Mix.install([
  {:exla, "~> 0.7", github: "elixir-nx/nx", sparse: "exla"},
  {:nx, "~> 0.7", github: "elixir-nx/nx", sparse: "nx", override: true},
  {:nx_iree, "~> 0.1", git: "https://github.com/elixir-nx/nx_iree"},
  {:kino, "~> 0.13"},
  {:benchee, "~> 1.3"},
  {:statistics, "~> 0.6"}
])

中間表現の取得

Softmax 関数の中間表現を生成してみます

softmax = fn tensor ->  
  Nx.divide(
    Nx.exp(tensor),
    Nx.sum(Nx.exp(tensor), axes: [-1], keep_axes: true)
  )
end

input = Nx.tensor([1.0, 2.0, -1.0, -1.5])
args = [input]

mlir_module = EXLA.to_mlir_module(softmax, args)

Kino.Text.new(mlir_module)

実行結果

module {
  func.func public @main(%arg0: tensor<4xf32>) -> tensor<4xf32> {
    %0 = stablehlo.exponential %arg0 : tensor<4xf32>
    %1 = stablehlo.exponential %arg0 : tensor<4xf32>
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %2 = stablehlo.reduce(%1 init: %cst) across dimensions = [0] : (tensor<4xf32>, tensor<f32>) -> tensor<f32>
     reducer(%arg1: tensor<f32>, %arg2: tensor<f32>)  {
      %6 = stablehlo.add %arg2, %arg1 : tensor<f32>
      stablehlo.return %6 : tensor<f32>
    }
    %3 = stablehlo.reshape %2 : (tensor<f32>) -> tensor<1xf32>
    %4 = stablehlo.broadcast_in_dim %3, dims = [0] : (tensor<1xf32>) -> tensor<4xf32>
    %5 = stablehlo.divide %0, %4 : tensor<4xf32>
    return %5 : tensor<4xf32>
  }
}

何となく、ちゃんとできていそうな気がします

コンパイル

デバイスの取得

dev =
  NxIREE.list_devices("metal")
  |> elem(1)
  |> hd()

フラグ設定

flags = [
  "--iree-hal-target-backends=metal-spirv",
  "--iree-input-type=stablehlo_xla",
  "--iree-execution-model=async-internal"
]
module = NxIREE.compile(mlir_module, flags)

演算処理の実行

NxIREE.call(module, args, device: dev)

実行結果

{:ok,
 [
   #Nx.Tensor<
     f32[4]
     NxIREE.Tensor(metal://default)
     [0.2540842592716217, 0.6906726360321045, 0.034386567771434784, 0.020856507122516632]
   >
 ]}

直接 Softmax を実行してみます

softmax.(input)

実行結果

#Nx.Tensor<
  f32[4]
  [0.2540842592716217, 0.6906726956367493, 0.034386567771434784, 0.02085650898516178]
>

誤差はあるものの、ほぼ同じ結果です

速度比較

Nx.BinaryBackend の場合と EXLA.Backend の場合、 NxIREE の場合で速度比較してみます

exla_input = Nx.backend_transfer(input, EXLA.Backend)
Benchee.run(%{
  "nx" => fn -> softmax.(input) end,
  "exla" => fn -> softmax.(exla_input) end,
  "nx_iree" => fn -> NxIREE.call(module, [input], device: dev) end
})

実行結果

Name              ips        average  deviation         median         99th %
nx           265.49 K        3.77 μs   ±257.11%        3.46 μs        6.13 μs
exla          14.52 K       68.88 μs    ±43.40%       66.37 μs      104.92 μs
nx_iree        1.79 K      557.50 μs    ±14.71%      550.14 μs      860.11 μs

Comparison: 
nx           265.49 K
exla          14.52 K - 18.29x slower +65.11 μs
nx_iree        1.79 K - 148.01x slower +553.74 μs

予測に反して Nx.BinaryBackend が速く、 NxIREE が遅くなってしまいました

恐らく行列が小さすぎることが問題なので、行列を大きくしてみます

長さ 100,000 のテンソルを用意しました

input =
  1..100000
  |> Enum.map(fn _ -> Statistics.Distributions.Normal.rand() end)
  |> Nx.tensor()

args = [input]

mlir_module = EXLA.to_mlir_module(softmax, args)

Kino.Text.new(mlir_module)

実行結果

module {
  func.func public @main(%arg0: tensor<100000xf32>) -> tensor<100000xf32> {
    %0 = stablehlo.exponential %arg0 : tensor<100000xf32>
    %1 = stablehlo.exponential %arg0 : tensor<100000xf32>
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %2 = stablehlo.reduce(%1 init: %cst) across dimensions = [0] : (tensor<100000xf32>, tensor<f32>) -> tensor<f32>
     reducer(%arg1: tensor<f32>, %arg2: tensor<f32>)  {
      %6 = stablehlo.add %arg2, %arg1 : tensor<f32>
      stablehlo.return %6 : tensor<f32>
    }
    %3 = stablehlo.reshape %2 : (tensor<f32>) -> tensor<1xf32>
    %4 = stablehlo.broadcast_in_dim %3, dims = [0] : (tensor<1xf32>) -> tensor<100000xf32>
    %5 = stablehlo.divide %0, %4 : tensor<100000xf32>
    return %5 : tensor<100000xf32>
  }
}

コンパイルします

module = NxIREE.compile(mlir_module, flags)

改めて速度比較してみます

Benchee.run(%{
  "nx" => fn -> softmax.(input) end,
  "exla" => fn -> softmax.(exla_input) end,
  "nx_iree" => fn -> NxIREE.call(module, [input], device: dev) end
})

実行結果

Name              ips        average  deviation         median         99th %
exla          14.29 K       70.00 μs    ±26.28%       67.71 μs      101.50 μs
nx_iree        1.13 K      886.81 μs    ±22.57%      854.91 μs     1675.55 μs
nx           0.0232 K    43122.48 μs     ±3.44%    42999.11 μs    52094.74 μs

Comparison: 
exla          14.29 K
nx_iree        1.13 K - 12.67x slower +816.81 μs
nx           0.0232 K - 616.08x slower +43052.48 μs

EXLA が最も速いという結果になりましたが、 BiraryBackend と比べれば IREE もかなり高速です

まとめ

NxIREE を使うことで、様々な環境で高速に演算処理が実行できそうです

まだ未リリースなので、これからの発展に期待ですね

21
4
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
21
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?