本記事はElixirで機械学習/ディープラーニングができるようになるnumpy likeなライブラリ Nxを使って
「ゼロから作るDeep Learning ―Pythonで学ぶディープラーニングの理論と実装」
をElixirで書いていこうという記事になります。
今回は前回の記事をgoogle XLAをbackendで使うExlaで高速化していきます
準備編
exla setup
1章 pythonの基本 -> とばします
2章 パーセプトロン -> とばします
3章 ニューラルネットワーク
with exla
4章 ニューラルネットワークの学習
5章 誤差逆伝播法
Nx.Defn.Kernel.grad
6章 学習に関するテクニック -> とばします
7章 畳み込みニューラルネットワーク
install exla
※2022/04/14 追記
以下を追加してdeps.getを実行してください
def deps do
  [
    {:exla, "~> 0.1.0-dev", github: "elixir-nx/nx", sparse: "exla"},
    {:nx, "~> 0.1", github: "elixir-nx/nx", sparse: "nx", override: true}
  ]
end
現在はコンパイル済みバイナリを提供するようなったので、以下の手順は不要です
https://github.com/elixir-nx/nx/tree/main/exla#common-installation-issues
documentにかいてあるとおりに
コンパイルに使うbazelをインストールして
asdf plugin-add bazel
asdf install bazel 3.1.0
asdf global bazel 3.1.0
mix.exsにexlaとnxをいれて
mix deps.get
mix deps.compile
でインストールします
origin/master' did not match any file(s) known to git
とエラーがでたら branch: "main"オプションを追加してください
xlaコンパイルに30分以上かかりますので気長に待ちましょう
def deps do
  [
    {:exla, "~> 0.1.0-dev", github: "elixir-nx/nx", sparse: "exla"},
    {:nx, "~> 0.1.0-dev", github: "elixir-nx/nx", sparse: "nx", override: true}
  ]
end
exla 有効化
exlaはdefnで定義した関数をxlaで実行します
defn内でMapは使えないので注意が必要です
defmodule MnistNx do
  import Nx.Defn
  @defn_compiler {EXLA, max_float_type: {:f, 64}} # これ追加
  defn predict(x, {w1,w2,w3,b1,b2,b3}) do
    x
    |> Nx.dot(w1)
    |> Nx.add(b1)
    |> sigmoid()
    |> Nx.dot(w2)
    |> Nx.add(b2)
    |> sigmoid()
    |> Nx.dot(w3)
    |> Nx.add(b3)
    |> softmax()
  end
  defn sigmoid(tensor) do
    1 / (1 + Nx.exp(-tensor))
  end
  defn softmax(tensor) do
    tensor = Nx.add(tensor, -Nx.reduce_max(tensor))
    tensor
    |> Nx.exp()
    |> Nx.divide( tensor |> Nx.exp() |> Nx.sum())
  end
  def acc(x,t,network) do
    {row, _} = Nx.shape(x)
    Enum.to_list(0..(row - 1))
    |> Nx.tensor()
    |> Nx.map(fn i ->
      predict(x[i],network)
      |> Nx.argmax()
      |> Nx.equal(t[i])
    end)
    |> Nx.sum
    |> Nx.divide(row)
  end
  def acc_batch(x,t,network) do
    x_b = x |> Nx.to_batched_list(100)
    t_b = t |> Nx.to_batched_list(100)
    {row, _} = x |> Nx.shape()
    batch = Enum.count(x_b)
    Enum.to_list(0..(batch-1))
    |> Nx.tensor()
    |> Nx.map(fn i ->
      predict(Enum.at(x_b,Nx.to_number(i)),network)
      |> Nx.argmax(axis: 1)
      |> Nx.equal(Enum.at(t_b,Nx.to_number(i)))
      |> Nx.sum()
    end)
    |> Nx.sum()
    |> Nx.divide(row)
  end
end
exlaのオプションはベンチ用のコードを参考にしましょう
https://github.com/elixir-nx/nx/blob/main/exla/bench/softmax.exs
ベンチ
bench code
defmodule NxNn do
  def get_data() do
    x_test = Dataset.test_image() |> Nx.tensor() |> (& Nx.divide(&1, Nx.reduce_max(&1))).()
    t_test = Dataset.test_label() |> Nx.tensor()
    {x_test, t_test}
  end
  def init_network() do
    {w1,w2,w3,b1,b2,b3} = PklLoad.load("pkl/sample_weight.pkl")
    {
      Nx.tensor(w1),
      Nx.tensor(w2),
      Nx.tensor(w3),
      Nx.tensor(b1),
      Nx.tensor(b2),
      Nx.tensor(b3)
    }
  end
  def acc do
    {x, t} = get_data()
    IO.puts("Load Data")
    Benchee.run(%{
      "exla cpu" => fn -> MnistNx.acc(x, t, init_network()) end,
      "exla cpu batch" => fn -> MnistNx.acc_batch(x,t,init_network()) end,
      "nx" => fn -> Mnist.acc(x, t, init_network()) end,
      "nx batch" => fn ->  Mnist.acc(x,t,init_network()) end
    })
   end
end
Load Data
Operating System: macOS
CPU Information: Intel(R) Core(TM) i7-7660U CPU @ 2.50GHz
Number of Available Cores: 4
Available memory: 16 GB
Elixir 1.11.3
Erlang 23.2.5
Benchmark suite executing with the following configuration:
warmup: 2 s
time: 5 s
memory time: 0 ns
parallel: 1
inputs: none specified
Estimated total run time: 28 s
Benchmarking exla cpu...
Benchmarking exla cpu batch...
Benchmarking nx...
Benchmarking nx batch...
Name                     ips        average  deviation         median         99th %
exla cpu batch          1.34     0.0125 min     ±4.21%     0.0123 min     0.0132 min
exla cpu                0.25     0.0668 min     ±4.29%     0.0668 min     0.0688 min
nx                   0.00294       5.67 min     ±0.00%       5.67 min       5.67 min
nx batch             0.00271       6.14 min     ±0.00%       6.14 min       6.14 min
Comparison: 
exla cpu batch          1.34
exla cpu                0.25 - 5.35x slower +0.0543 min
nx                   0.00294 - 454.65x slower +5.66 min
nx batch             0.00271 - 492.39x slower +6.13 min
backendをXLAをCPUモードにしただけで劇的に速度が上がりました!
GPUを使えばもっと早くなるので今後が楽しみです
全体のコードはこちらになります
https://github.com/thehaigo/nx_sample
https://github.com/thehaigo/nx_dl
参考ページ
https://github.com/elixir-nx/nx/tree/main/nx
https://github.com/elixir-nx/nx/tree/main/exla
https://github.com/elixir-nx/nx/blob/main/exla/bench/softmax.exs

