LoginSignup
15
2

More than 1 year has passed since last update.

Nx の Defn の利点,Defn で書けない Nx のモジュールと関数 (2022年版)

Last updated at Posted at 2022-12-23

Nx の Defn の利点

利点その1: 数式をわかりやすく定義できる

Nxの Defn を使うと,数式を定義できます.

公式ドキュメントに上がっている例です.

defmodule MyModule do
  import Nx.Defn

  defn softmax(t) do
    Nx.exp(t) / Nx.sum(Nx.exp(t))
  end
end

ここでは / を含む式が書かれていますが,これは次のような式と等価です.

Nx.divide(Nx.exp(t), Nx.sum(Nx.exp(t)))

どうでしょう? 前者の方が読みやすいですよね?

Defn では,次に書かれている記号や関数を使うことができます.

利点その2: Nxバックエンドで高速化できる

Nx の Defn にはもう1つ重要な利点があります.EXLAやTorchXなどのNxバックエンドを利用して,数式全体をAOT/JITコンパイルすることで高速化できます.AOTコンパイルは Ahead of time compile ということで,事前にコンパイルしてネイティブコードにします.JITコンパイルは Just in time compile ということで,実行中にコンパイルしてネイティブコードにします.

次のようなベンチマークコードで検証しましょう.

Mix.install(
  [
    {:nx, "~> 0.4"},
    {:exla, "~> 0.4"},
    {:benchee, "~> 1.1"}
  ]
)


defmodule Softmax do
  import Nx.Defn

  defn softmax_defn(n), do: Nx.exp(n) / Nx.sum(Nx.exp(n))

  def softmax_def(n) do
    Nx.divide(Nx.exp(n), Nx.sum(Nx.exp(n)))
  end
end

host_jit = EXLA.jit(&Softmax.softmax_defn/1)

key = Nx.Random.key(12)
size = [1_000, 1_000_000]

inputs =
  size
  |> Enum.map(fn size ->
      {"#{size}", Nx.Random.uniform(key, shape: {size}, type: :f32) |> elem(0)}
  end)
  |> Map.new()

Benchee.run(
  %{
    "Nx def" => fn input -> Softmax.softmax_def(input) end,
    "Nx defn" => fn input -> Softmax.softmax_defn(input) end,
    "EXLA defn" => fn input -> host_jit.(input) end
  },
  inputs: inputs,
  time: 10,
  memory_time: 2
)

M2 Macbook Airで検証した結果が次のとおりです.EXLAはCPUでのみ動作させています.

% elixir defn_bench.exs
Operating System: macOS
CPU Information: Apple M2
Number of Available Cores: 8
Available memory: 24 GB
Elixir 1.14.2
Erlang 25.2

Benchmark suite executing with the following configuration:
warmup: 2 s
time: 10 s
memory time: 2 s
reduction time: 0 ns
parallel: 1
inputs: 1000, 1000000
Estimated total run time: 1.40 min

Benchmarking EXLA defn with input 1000 ...

02:58:45.731 [info] TfrtCpuClient created.
Benchmarking EXLA defn with input 1000000 ...
Benchmarking Nx def with input 1000 ...
Benchmarking Nx def with input 1000000 ...
Benchmarking Nx defn with input 1000 ...
Benchmarking Nx defn with input 1000000 ...

##### With input 1000 #####
Name                ips        average  deviation         median         99th %
EXLA defn       56.16 K       17.81 μs   ±267.80%       16.33 μs       39.92 μs
Nx def           3.56 K      280.76 μs    ±13.76%      278.54 μs      376.08 μs
Nx defn          2.51 K      398.98 μs    ±23.27%      384.67 μs      605.54 μs

Comparison: 
EXLA defn       56.16 K
Nx def           3.56 K - 15.77x slower +262.95 μs
Nx defn          2.51 K - 22.41x slower +381.18 μs

Memory usage statistics:

Name         Memory usage
EXLA defn         5.07 KB
Nx def          595.81 KB - 117.51x memory usage +590.74 KB
Nx defn         603.92 KB - 119.11x memory usage +598.85 KB

**All measurements for memory usage were the same**

##### With input 1000000 #####
Name                ips        average  deviation         median         99th %
EXLA defn       1435.25        0.70 ms    ±56.19%        0.60 ms        2.60 ms
Nx defn            4.81      208.02 ms     ±3.32%      210.58 ms      219.95 ms
Nx def             4.38      228.25 ms     ±3.57%      226.74 ms      237.45 ms

Comparison: 
EXLA defn       1435.25
Nx defn            4.81 - 298.57x slower +207.33 ms
Nx def             4.38 - 327.60x slower +227.55 ms

Memory usage statistics:

Name         Memory usage
EXLA defn      0.00495 MB
Nx defn         579.70 MB - 117076.03x memory usage +579.69 MB
Nx def          579.64 MB - 117064.64x memory usage +579.64 MB

**All measurements for memory usage were the same**

Nx defnNx defより若干高速ですね.それ以上に,EXLA defnは圧倒的に高速で,要素数1,000の時には15.77〜22.41倍,要素数1,000,000の時には298.57〜327.60倍も高速です.

Defnで書けないNxのモジュールと関数

Nx.Defn.Compilerモジュールの@allowed_modules@forbidden_opsにまとめられています.

Nxバージョン0.4.1では次のとおりです.

  # Modules allowed in defn
  @allowed_modules [Nx, Nx.Constants, Nx.Defn, Nx.Defn.Kernel, Nx.LinAlg, Nx.Type]

  # These operations do not have valid meaning for Nx.Defn.Expr
  @forbidden_ops [:backend_copy, :backend_deallocate, :backend_transfer] ++
                   [:to_binary, :to_number, :to_flat_list, :to_heatmap, :to_batched] ++
                   [:from_numpy, :from_numpy_archive, :compatible?, :default_backend] ++
                   [:serialize, :deserialize]

次のモジュールを使うことができます.

次の関数を使うことができません.

15
2
7

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
15
2