本記事はElixirで機械学習/ディープラーニングができるようになるnumpy likeなライブラリ Nxを使って
「ゼロから作るDeep Learning ―Pythonで学ぶディープラーニングの理論と実装」
をElixirで書いていこうという記事になります。
今回は前回の記事をgoogle XLAをbackendで使うExlaで高速化していきます
準備編
exla setup
1章 pythonの基本 -> とばします
2章 パーセプトロン -> とばします
3章 ニューラルネットワーク
with exla
4章 ニューラルネットワークの学習
5章 誤差逆伝播法
Nx.Defn.Kernel.grad
6章 学習に関するテクニック -> とばします
7章 畳み込みニューラルネットワーク
install exla
※2022/04/14 追記
以下を追加してdeps.getを実行してください
def deps do
[
{:exla, "~> 0.1.0-dev", github: "elixir-nx/nx", sparse: "exla"},
{:nx, "~> 0.1", github: "elixir-nx/nx", sparse: "nx", override: true}
]
end
現在はコンパイル済みバイナリを提供するようなったので、以下の手順は不要です
https://github.com/elixir-nx/nx/tree/main/exla#common-installation-issues
documentにかいてあるとおりに
コンパイルに使うbazelをインストールして
asdf plugin-add bazel
asdf install bazel 3.1.0
asdf global bazel 3.1.0
mix.exsにexlaとnxをいれて
mix deps.get
mix deps.compile
でインストールします
origin/master' did not match any file(s) known to git
とエラーがでたら branch: "main"オプションを追加してください
xlaコンパイルに30分以上かかりますので気長に待ちましょう
def deps do
[
{:exla, "~> 0.1.0-dev", github: "elixir-nx/nx", sparse: "exla"},
{:nx, "~> 0.1.0-dev", github: "elixir-nx/nx", sparse: "nx", override: true}
]
end
exla 有効化
exlaはdefnで定義した関数をxlaで実行します
defn内でMapは使えないので注意が必要です
defmodule MnistNx do
import Nx.Defn
@defn_compiler {EXLA, max_float_type: {:f, 64}} # これ追加
defn predict(x, {w1,w2,w3,b1,b2,b3}) do
x
|> Nx.dot(w1)
|> Nx.add(b1)
|> sigmoid()
|> Nx.dot(w2)
|> Nx.add(b2)
|> sigmoid()
|> Nx.dot(w3)
|> Nx.add(b3)
|> softmax()
end
defn sigmoid(tensor) do
1 / (1 + Nx.exp(-tensor))
end
defn softmax(tensor) do
tensor = Nx.add(tensor, -Nx.reduce_max(tensor))
tensor
|> Nx.exp()
|> Nx.divide( tensor |> Nx.exp() |> Nx.sum())
end
def acc(x,t,network) do
{row, _} = Nx.shape(x)
Enum.to_list(0..(row - 1))
|> Nx.tensor()
|> Nx.map(fn i ->
predict(x[i],network)
|> Nx.argmax()
|> Nx.equal(t[i])
end)
|> Nx.sum
|> Nx.divide(row)
end
def acc_batch(x,t,network) do
x_b = x |> Nx.to_batched_list(100)
t_b = t |> Nx.to_batched_list(100)
{row, _} = x |> Nx.shape()
batch = Enum.count(x_b)
Enum.to_list(0..(batch-1))
|> Nx.tensor()
|> Nx.map(fn i ->
predict(Enum.at(x_b,Nx.to_number(i)),network)
|> Nx.argmax(axis: 1)
|> Nx.equal(Enum.at(t_b,Nx.to_number(i)))
|> Nx.sum()
end)
|> Nx.sum()
|> Nx.divide(row)
end
end
exlaのオプションはベンチ用のコードを参考にしましょう
https://github.com/elixir-nx/nx/blob/main/exla/bench/softmax.exs
ベンチ
bench code
defmodule NxNn do
def get_data() do
x_test = Dataset.test_image() |> Nx.tensor() |> (& Nx.divide(&1, Nx.reduce_max(&1))).()
t_test = Dataset.test_label() |> Nx.tensor()
{x_test, t_test}
end
def init_network() do
{w1,w2,w3,b1,b2,b3} = PklLoad.load("pkl/sample_weight.pkl")
{
Nx.tensor(w1),
Nx.tensor(w2),
Nx.tensor(w3),
Nx.tensor(b1),
Nx.tensor(b2),
Nx.tensor(b3)
}
end
def acc do
{x, t} = get_data()
IO.puts("Load Data")
Benchee.run(%{
"exla cpu" => fn -> MnistNx.acc(x, t, init_network()) end,
"exla cpu batch" => fn -> MnistNx.acc_batch(x,t,init_network()) end,
"nx" => fn -> Mnist.acc(x, t, init_network()) end,
"nx batch" => fn -> Mnist.acc(x,t,init_network()) end
})
end
end
Load Data
Operating System: macOS
CPU Information: Intel(R) Core(TM) i7-7660U CPU @ 2.50GHz
Number of Available Cores: 4
Available memory: 16 GB
Elixir 1.11.3
Erlang 23.2.5
Benchmark suite executing with the following configuration:
warmup: 2 s
time: 5 s
memory time: 0 ns
parallel: 1
inputs: none specified
Estimated total run time: 28 s
Benchmarking exla cpu...
Benchmarking exla cpu batch...
Benchmarking nx...
Benchmarking nx batch...
Name ips average deviation median 99th %
exla cpu batch 1.34 0.0125 min ±4.21% 0.0123 min 0.0132 min
exla cpu 0.25 0.0668 min ±4.29% 0.0668 min 0.0688 min
nx 0.00294 5.67 min ±0.00% 5.67 min 5.67 min
nx batch 0.00271 6.14 min ±0.00% 6.14 min 6.14 min
Comparison:
exla cpu batch 1.34
exla cpu 0.25 - 5.35x slower +0.0543 min
nx 0.00294 - 454.65x slower +5.66 min
nx batch 0.00271 - 492.39x slower +6.13 min
backendをXLAをCPUモードにしただけで劇的に速度が上がりました!
GPUを使えばもっと早くなるので今後が楽しみです
全体のコードはこちらになります
https://github.com/thehaigo/nx_sample
https://github.com/thehaigo/nx_dl
参考ページ
https://github.com/elixir-nx/nx/tree/main/nx
https://github.com/elixir-nx/nx/tree/main/exla
https://github.com/elixir-nx/nx/blob/main/exla/bench/softmax.exs