LoginSignup
13
7

More than 1 year has passed since last update.

Nxで始めるゼロから作るディープラーニング 3章 ニューラルネットワーク with Exla

Last updated at Posted at 2021-02-24

本記事はElixirで機械学習/ディープラーニングができるようになるnumpy likeなライブラリ Nxを使って
ゼロから作るDeep Learning ―Pythonで学ぶディープラーニングの理論と実装
をElixirで書いていこうという記事になります。

今回は前回の記事をgoogle XLAをbackendで使うExlaで高速化していきます

準備編
exla setup
1章 pythonの基本 -> とばします
2章 パーセプトロン -> とばします
3章 ニューラルネットワーク
with exla
4章 ニューラルネットワークの学習
5章 誤差逆伝播法
Nx.Defn.Kernel.grad
6章 学習に関するテクニック -> とばします
7章 畳み込みニューラルネットワーク

install exla

※2022/04/14 追記
以下を追加してdeps.getを実行してください

mix.exs
def deps do
  [
    {:exla, "~> 0.1.0-dev", github: "elixir-nx/nx", sparse: "exla"},
    {:nx, "~> 0.1", github: "elixir-nx/nx", sparse: "nx", override: true}
  ]
end

現在はコンパイル済みバイナリを提供するようなったので、以下の手順は不要です

https://github.com/elixir-nx/nx/tree/main/exla#common-installation-issues
documentにかいてあるとおりに
コンパイルに使うbazelをインストールして

asdf plugin-add bazel
asdf install bazel 3.1.0
asdf global bazel 3.1.0

mix.exsにexlaとnxをいれて
mix deps.get
mix deps.compile
でインストールします

origin/master' did not match any file(s) known to git
とエラーがでたら branch: "main"オプションを追加してください
xlaコンパイルに30分以上かかりますので気長に待ちましょう

mix.exs
def deps do
  [
    {:exla, "~> 0.1.0-dev", github: "elixir-nx/nx", sparse: "exla"},
    {:nx, "~> 0.1.0-dev", github: "elixir-nx/nx", sparse: "nx", override: true}
  ]
end

exla 有効化

exlaはdefnで定義した関数をxlaで実行します
defn内でMapは使えないので注意が必要です

mnist_nx.ex
defmodule MnistNx do
  import Nx.Defn
  @defn_compiler {EXLA, max_float_type: {:f, 64}} # これ追加
  defn predict(x, {w1,w2,w3,b1,b2,b3}) do
    x
    |> Nx.dot(w1)
    |> Nx.add(b1)
    |> sigmoid()
    |> Nx.dot(w2)
    |> Nx.add(b2)
    |> sigmoid()
    |> Nx.dot(w3)
    |> Nx.add(b3)
    |> softmax()
  end

  defn sigmoid(tensor) do
    1 / (1 + Nx.exp(-tensor))
  end

  defn softmax(tensor) do
    tensor = Nx.add(tensor, -Nx.reduce_max(tensor))
    tensor
    |> Nx.exp()
    |> Nx.divide( tensor |> Nx.exp() |> Nx.sum())
  end

  def acc(x,t,network) do
    {row, _} = Nx.shape(x)
    Enum.to_list(0..(row - 1))
    |> Nx.tensor()
    |> Nx.map(fn i ->
      predict(x[i],network)
      |> Nx.argmax()
      |> Nx.equal(t[i])
    end)
    |> Nx.sum
    |> Nx.divide(row)
  end

  def acc_batch(x,t,network) do
    x_b = x |> Nx.to_batched_list(100)
    t_b = t |> Nx.to_batched_list(100)
    {row, _} = x |> Nx.shape()
    batch = Enum.count(x_b)
    Enum.to_list(0..(batch-1))
    |> Nx.tensor()
    |> Nx.map(fn i ->
      predict(Enum.at(x_b,Nx.to_number(i)),network)
      |> Nx.argmax(axis: 1)
      |> Nx.equal(Enum.at(t_b,Nx.to_number(i)))
      |> Nx.sum()
    end)
    |> Nx.sum()
    |> Nx.divide(row)
  end
end

exlaのオプションはベンチ用のコードを参考にしましょう
https://github.com/elixir-nx/nx/blob/main/exla/bench/softmax.exs

ベンチ

bench code

nx_nn.ex
defmodule NxNn do
  def get_data() do
    x_test = Dataset.test_image() |> Nx.tensor() |> (& Nx.divide(&1, Nx.reduce_max(&1))).()
    t_test = Dataset.test_label() |> Nx.tensor()
    {x_test, t_test}
  end

  def init_network() do
    {w1,w2,w3,b1,b2,b3} = PklLoad.load("pkl/sample_weight.pkl")
    {
      Nx.tensor(w1),
      Nx.tensor(w2),
      Nx.tensor(w3),
      Nx.tensor(b1),
      Nx.tensor(b2),
      Nx.tensor(b3)
    }
  end

  def acc do
    {x, t} = get_data()
    IO.puts("Load Data")
    Benchee.run(%{
      "exla cpu" => fn -> MnistNx.acc(x, t, init_network()) end,
      "exla cpu batch" => fn -> MnistNx.acc_batch(x,t,init_network()) end,
      "nx" => fn -> Mnist.acc(x, t, init_network()) end,
      "nx batch" => fn ->  Mnist.acc(x,t,init_network()) end
    })
   end
end

Load Data
Operating System: macOS
CPU Information: Intel(R) Core(TM) i7-7660U CPU @ 2.50GHz
Number of Available Cores: 4
Available memory: 16 GB
Elixir 1.11.3
Erlang 23.2.5

Benchmark suite executing with the following configuration:
warmup: 2 s
time: 5 s
memory time: 0 ns
parallel: 1
inputs: none specified
Estimated total run time: 28 s

Benchmarking exla cpu...
Benchmarking exla cpu batch...
Benchmarking nx...
Benchmarking nx batch...

Name                     ips        average  deviation         median         99th %
exla cpu batch          1.34     0.0125 min     ±4.21%     0.0123 min     0.0132 min
exla cpu                0.25     0.0668 min     ±4.29%     0.0668 min     0.0688 min
nx                   0.00294       5.67 min     ±0.00%       5.67 min       5.67 min
nx batch             0.00271       6.14 min     ±0.00%       6.14 min       6.14 min

Comparison: 
exla cpu batch          1.34
exla cpu                0.25 - 5.35x slower +0.0543 min
nx                   0.00294 - 454.65x slower +5.66 min
nx batch             0.00271 - 492.39x slower +6.13 min

backendをXLAをCPUモードにしただけで劇的に速度が上がりました!
GPUを使えばもっと早くなるので今後が楽しみです

全体のコードはこちらになります
https://github.com/thehaigo/nx_sample
https://github.com/thehaigo/nx_dl

参考ページ

https://github.com/elixir-nx/nx/tree/main/nx
https://github.com/elixir-nx/nx/tree/main/exla
https://github.com/elixir-nx/nx/blob/main/exla/bench/softmax.exs

13
7
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
13
7