Nx の Defn
の利点
利点その1: 数式をわかりやすく定義できる
Nxの Defn
を使うと,数式を定義できます.
公式ドキュメントに上がっている例です.
defmodule MyModule do
import Nx.Defn
defn softmax(t) do
Nx.exp(t) / Nx.sum(Nx.exp(t))
end
end
ここでは /
を含む式が書かれていますが,これは次のような式と等価です.
Nx.divide(Nx.exp(t), Nx.sum(Nx.exp(t)))
どうでしょう? 前者の方が読みやすいですよね?
Defn
では,次に書かれている記号や関数を使うことができます.
利点その2: Nxバックエンドで高速化できる
Nx の Defn
にはもう1つ重要な利点があります.EXLAやTorchXなどのNxバックエンドを利用して,数式全体をAOT/JITコンパイルすることで高速化できます.AOTコンパイルは Ahead of time compile ということで,事前にコンパイルしてネイティブコードにします.JITコンパイルは Just in time compile ということで,実行中にコンパイルしてネイティブコードにします.
次のようなベンチマークコードで検証しましょう.
Mix.install(
[
{:nx, "~> 0.4"},
{:exla, "~> 0.4"},
{:benchee, "~> 1.1"}
]
)
defmodule Softmax do
import Nx.Defn
defn softmax_defn(n), do: Nx.exp(n) / Nx.sum(Nx.exp(n))
def softmax_def(n) do
Nx.divide(Nx.exp(n), Nx.sum(Nx.exp(n)))
end
end
host_jit = EXLA.jit(&Softmax.softmax_defn/1)
key = Nx.Random.key(12)
size = [1_000, 1_000_000]
inputs =
size
|> Enum.map(fn size ->
{"#{size}", Nx.Random.uniform(key, shape: {size}, type: :f32) |> elem(0)}
end)
|> Map.new()
Benchee.run(
%{
"Nx def" => fn input -> Softmax.softmax_def(input) end,
"Nx defn" => fn input -> Softmax.softmax_defn(input) end,
"EXLA defn" => fn input -> host_jit.(input) end
},
inputs: inputs,
time: 10,
memory_time: 2
)
M2 Macbook Airで検証した結果が次のとおりです.EXLAはCPUでのみ動作させています.
% elixir defn_bench.exs
Operating System: macOS
CPU Information: Apple M2
Number of Available Cores: 8
Available memory: 24 GB
Elixir 1.14.2
Erlang 25.2
Benchmark suite executing with the following configuration:
warmup: 2 s
time: 10 s
memory time: 2 s
reduction time: 0 ns
parallel: 1
inputs: 1000, 1000000
Estimated total run time: 1.40 min
Benchmarking EXLA defn with input 1000 ...
02:58:45.731 [info] TfrtCpuClient created.
Benchmarking EXLA defn with input 1000000 ...
Benchmarking Nx def with input 1000 ...
Benchmarking Nx def with input 1000000 ...
Benchmarking Nx defn with input 1000 ...
Benchmarking Nx defn with input 1000000 ...
##### With input 1000 #####
Name ips average deviation median 99th %
EXLA defn 56.16 K 17.81 μs ±267.80% 16.33 μs 39.92 μs
Nx def 3.56 K 280.76 μs ±13.76% 278.54 μs 376.08 μs
Nx defn 2.51 K 398.98 μs ±23.27% 384.67 μs 605.54 μs
Comparison:
EXLA defn 56.16 K
Nx def 3.56 K - 15.77x slower +262.95 μs
Nx defn 2.51 K - 22.41x slower +381.18 μs
Memory usage statistics:
Name Memory usage
EXLA defn 5.07 KB
Nx def 595.81 KB - 117.51x memory usage +590.74 KB
Nx defn 603.92 KB - 119.11x memory usage +598.85 KB
**All measurements for memory usage were the same**
##### With input 1000000 #####
Name ips average deviation median 99th %
EXLA defn 1435.25 0.70 ms ±56.19% 0.60 ms 2.60 ms
Nx defn 4.81 208.02 ms ±3.32% 210.58 ms 219.95 ms
Nx def 4.38 228.25 ms ±3.57% 226.74 ms 237.45 ms
Comparison:
EXLA defn 1435.25
Nx defn 4.81 - 298.57x slower +207.33 ms
Nx def 4.38 - 327.60x slower +227.55 ms
Memory usage statistics:
Name Memory usage
EXLA defn 0.00495 MB
Nx defn 579.70 MB - 117076.03x memory usage +579.69 MB
Nx def 579.64 MB - 117064.64x memory usage +579.64 MB
**All measurements for memory usage were the same**
Nx defn
はNx def
より若干高速ですね.それ以上に,EXLA defn
は圧倒的に高速で,要素数1,000の時には15.77〜22.41倍,要素数1,000,000の時には298.57〜327.60倍も高速です.
Defnで書けないNxのモジュールと関数
Nx.Defn.Compiler
モジュールの@allowed_modules
と@forbidden_ops
にまとめられています.
Nxバージョン0.4.1では次のとおりです.
# Modules allowed in defn
@allowed_modules [Nx, Nx.Constants, Nx.Defn, Nx.Defn.Kernel, Nx.LinAlg, Nx.Type]
# These operations do not have valid meaning for Nx.Defn.Expr
@forbidden_ops [:backend_copy, :backend_deallocate, :backend_transfer] ++
[:to_binary, :to_number, :to_flat_list, :to_heatmap, :to_batched] ++
[:from_numpy, :from_numpy_archive, :compatible?, :default_backend] ++
[:serialize, :deserialize]
次のモジュールを使うことができます.
-
Nx
: https://hexdocs.pm/nx/Nx.html -
Nx.Constants
: https://hexdocs.pm/nx/Nx.Constants.html -
Nx.Defn
: https://hexdocs.pm/nx/Nx.Defn.html -
Nx.Defn.Kernel
: https://hexdocs.pm/nx/Nx.Defn.Kernel.html -
Nx.LinAlg
: https://hexdocs.pm/nx/Nx.LinAlg.html -
Nx.Type
: https://hexdocs.pm/nx/Nx.Type.html
次の関数を使うことができません.
-
Nx.backend_copy/2
: https://hexdocs.pm/nx/Nx.html#backend_copy/2 -
Nx.backend_deallocate/1
: https://hexdocs.pm/nx/Nx.html#backend_deallocate/1 -
Nx.backend_transfer/2
: https://hexdocs.pm/nx/Nx.html#backend_transfer/2 -
Nx.to_binary/2
: https://hexdocs.pm/nx/Nx.html#to_binary/2 -
Nx.to_number/1
: https://hexdocs.pm/nx/Nx.html#to_number/1 -
Nx.to_flat_list/2
: https://hexdocs.pm/nx/Nx.html#to_flat_list/2 -
Nx.to_batched/3
: https://hexdocs.pm/nx/Nx.html#to_batched/3 -
Nx.from_numpy/1
: https://hexdocs.pm/nx/Nx.html#from_numpy/1 -
Nx.from_numpy_archive/1
: https://hexdocs.pm/nx/Nx.html#from_numpy_archive/1 -
Nx.compatible?/2
: https://hexdocs.pm/nx/Nx.html#compatible?/2 -
Nx.default_backend/0
: https://hexdocs.pm/nx/Nx.html#default_backend/0 -
Nx.default_backend/1
: https://hexdocs.pm/nx/Nx.html#default_backend/1 -
Nx.serialize/2
: https://hexdocs.pm/nx/Nx.html#serialize/2 -
Nx.deserialize/2
: https://hexdocs.pm/nx/Nx.html#deserialize/2