LoginSignup
5
1

More than 3 years have passed since last update.

Nxで始めるゼロから作るディープラーニング Nx.Defn.Kernel.gradを読む

Last updated at Posted at 2021-03-29

はじめに

本記事はElixirで機械学習/ディープラーニングができるようになるnumpy likeなライブラリ Nxを使って
ゼロから作るDeep Learning ―Pythonで学ぶディープラーニングの理論と実装
をElixirで書いていこうという記事になります。

今回は4,5章で使用した Nxのgrad関数が何を行っているのかを読んでみた時のメモ書き的なものになります

準備編
exla setup
1章 pythonの基本 -> とばします
2章 パーセプトロン -> とばします
3章 ニューラルネットワーク
with exla
4章 ニューラルネットワークの学習
5章 誤差逆伝播法
Nx.Defn.Kernel.grad
6章 学習に関するテクニック -> とばします
7章 畳み込みニューラルネットワーク

わかったこと

読んでみたところ、
forwardを実行
forwardの結果を保持(cache)
grad関数で実行されるdefn内の処理を解析しbackward部分を自動生成(予想)
backwrad部分は各関数ごとにgrad.exやexpr.ex内に記載されている
ということがわかりました
grad関数というかnxやばいです

Nx.Defn.Karnel.grad

https://github.com/elixir-nx/nx/blob/main/nx/lib/nx/defn/kernel.ex
https://github.com/elixir-nx/nx/blob/a9f7ea8fc09483e5b65783281dd2b459bcda11ae/nx/lib/nx/defn/kernel.ex#L231-L234
第2引数が関数かチェックしてtransformを呼んでいます

def grad(var_or_vars, fun) when is_function(fun, 1) do
    {_value, grad} = Nx.Defn.Grad.transform(var_or_vars, fun)
    grad
end

Nx.Defn.Grad.transform
https://github.com/elixir-nx/nx/blob/main/nx/lib/nx/defn/grad.ex

  def transform(to_grad, fun) do
    # 1
    {to_grad, ids} =
      Tree.composite(to_grad, %{}, fn to_grad, ids ->
        validate_grad!(to_grad)
        to_grad = Expr.metadata(to_grad, %{__MODULE__ => :to_grad})
        {to_grad, Map.put(ids, to_grad.data.id, :to_grad)}
      end)

    # 2
    expr = to_grad |> fun.() |> validate_expr!()

    # Collect all IDs in the function environment and mark
    # them as stop grads. This is an optimization to avoid
    # traversing trees when not necessary.
    # 3
    {:env, env} = Function.info(fun, :env)
    ids = stop_grads(env, ids)

    # 4
    # Grad all the parameters at the same time to share subtrees.
    {graded, _} = to_grad(expr, Expr.tensor(1.0), {ids, %{}})

    # 5
    # Now traverse the expression again zerofying
    # the parts that comes from other variables.
    # We do so by encoding special nodes in the Expr
    # AST and unpack them as we verify.
    graded =
      Tree.composite(to_grad, fn to_grad ->
        id = to_grad.data.id
        {graded, _, _} = zerofy_ids(graded, %{}, Map.delete(ids, id))

        if graded.shape == to_grad.shape do
          graded
        else
          Nx.broadcast(graded, to_grad)
        end
      end)

    {expr, graded}
  end

1 to_gradの中のTensorにユニークなIDを付けて、Tensor自体とidのリストを返す

https://github.com/elixir-nx/nx/blob/a9f7ea8fc09483e5b65783281dd2b459bcda11ae/nx/lib/nx/defn/tree.ex#L85-L88
引数がTensorかを確認して第3引数の関数を実行

  def composite(tuple, acc, fun) when is_tuple(tuple) and is_function(fun, 2) do
    {list, acc} = Enum.map_reduce(Tuple.to_list(tuple), acc, &composite(&1, &2, fun))
    {List.to_tuple(list), acc}
  end

  def composite(%T{} = expr, acc, fun) when is_function(fun, 2) do
    fun.(expr, acc)
  end

https://github.com/elixir-nx/nx/blob/a9f7ea8fc09483e5b65783281dd2b459bcda11ae/nx/lib/nx/defn/expr.ex#L83-L86
https://github.com/elixir-nx/nx/blob/a9f7ea8fc09483e5b65783281dd2b459bcda11ae/nx/lib/nx/defn/expr.ex#L701-L706
https://github.com/elixir-nx/nx/blob/a9f7ea8fc09483e5b65783281dd2b459bcda11ae/nx/lib/nx/defn/expr.ex#L190
https://hexdocs.pm/elixir/Kernel.html#make_ref/0
メタデータを作成

  #L83-L86
  def metadata(expr, metadata) when is_map(metadata) do
    expr = to_expr(expr)
    expr(expr, expr.data.context, :metadata, [expr, metadata])
  end
  #L701-L706
  defp expr(tensor, context, op, args) do
    %{tensor | data: %Expr{id: id(), op: op, args: args, context: context}}
  end

  defp to_expr(%T{data: %Expr{}} = t),
    do: t
  # l190
  def id(), do: make_ref()

2 functionの実行結果をメタデータ付きtensorかをチェックする

https://github.com/elixir-nx/nx/blob/a9f7ea8fc09483e5b65783281dd2b459bcda11ae/nx/lib/nx/defn/grad.ex#L53-L55
functionの実行結果をメタデータ付きtensorかをチェックする

  expr = to_grad |> fun.() |> validate_expr!()

  defp validate_expr!(%T{data: %Expr{}, shape: {}} = expr) do
    expr
  end

3 imageとlabelを更新されないように固定する

https://hexdocs.pm/elixir/Function.html#info/2
https://github.com/elixir-nx/nx/blob/a9f7ea8fc09483e5b65783281dd2b459bcda11ae/nx/lib/nx/defn/grad.ex#L67-L83
Function.info(func,:env)を実行すると以下のようなデータが出力される
x_test => mnist test data 10000x784
t_test => mnist test label 10000x10

iex(17)> Function.info(&Ch5.TwoLayerNet.loss_g(&1,x_test,t_test,100),:env)    
{:env,
 [
   {[_@4: #Nx.Tensor<
        s64[10000][10]
        [
          [0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
          [0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
          [0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
          [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
          [0, 0, 0, 0, 1, ...],
          ...
        ]
>, _@5: #Nx.Tensor<
        f32[10000][784]
        [
          [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...],
          ...
        ]
>], :none, :none,
    [
      {:clause, 17, [{:var, 0, :_@7}], [],
       [
         {:call, 17,
          {:remote, 17, {:atom, 0, Ch5.TwoLayerNet}, {:atom, 17, :loss_g}},
          [
            {:var, 0, :_@7},
            {:var, 17, :_@5},
            {:var, 17, :_@4},
            {:integer, 0, 100}
          ]}
       ]}
    ]}
 ]}

これに対してstop_gradsを行うと最終的に(Tensor,ids)か(_,ids) になるので
Tensorのメタデータにあるidをキーにidsにstopフラグを立てる

  {:env, env} = Function.info(fun, :env)
  ids = stop_grads(env, ids)

  # L67-L83
  defp stop_grads(list, ids) when is_list(list),
    do: Enum.reduce(list, ids, &stop_grads/2)

  defp stop_grads(tuple, ids) when is_tuple(tuple),
    do: tuple |> Tuple.to_list() |> Enum.reduce(ids, &stop_grads/2)

  defp stop_grads(%T{data: %Expr{id: id}}, ids),
    do: Map.put(ids, id, :stop)

  defp stop_grads(%_{}, ids),
    do: ids

  defp stop_grads(map, ids) when is_map(map),
    do: map |> Map.values() |> Enum.reduce(ids, &stop_grads/2)

  defp stop_grads(_, ids),
    do: ids

4 再帰的なGrad

https://github.com/elixir-nx/nx/blob/a9f7ea8fc09483e5b65783281dd2b459bcda11ae/nx/lib/nx/defn/grad.ex#L160
Expr.tensor(1.0) はbackwardのdout = 1 ?
expr = {w1,b1,w2,b2}のidはMapのkeyにないので、最後のgrad(op, args, ans, res cache)を実行

  {graded, _} = to_grad(expr, Expr.tensor(1.0), {ids, %{}})
  ## Recursion

  # The gradient recursion.
  #
  # We keep two caches. One is the result cache, which is used for
  # when visiting the same nodes in the AST.
  #
  # The other cache is the JVP cache, that shares parts of the JVP
  # computation. Both are important to reduce the amount of nodes
  # in the AST.
  defp to_grad(expr, res, cache) do
    Tree.composite(expr, cache, fn
      %T{data: %Expr{id: id, op: op, args: args}} = ans, {result_cache, jvp_cache} = cache ->
        key = [id | res.data.id]

        case result_cache do
          %{^id => :stop} ->
            {Expr.tensor(0.0), cache}

          %{^id => :to_grad} ->
            {Expr.metadata(res, %{__MODULE__ => {:tainted, id}}), cache}

          %{^key => res} ->
            {res, cache}

          %{} ->
            case grad(op, args, ans, res, cache) do
              {res, {result_cache, jvp_cache}} ->
                {res, {Map.put(result_cache, key, res), jvp_cache}}

              :none ->
                jvps =
                  case jvp_cache do
                    %{^id => jvps} -> jvps
                    %{} -> jvp(op, args, ans)
                  end

                {res, {result_cache, jvp_cache}} = grad_jvps(jvps, ans, res, cache)
                {res, {Map.put(result_cache, key, res), Map.put(jvp_cache, id, jvps)}}
            end
        end
    end)
  end
iex(40)> %T{data: %Expr{id: id, op: op, args: args}} = Expr.tensor(w1)
#Nx.Tensor<
  f64[784][100]

  Nx.Defn.Expr
  tensor a  f64[784][100]
>
iex(41)> id
#Reference<0.3878774584.611057665.71016>
iex(42)> op
:tensor
iex(43)> args
[#Nx.Tensor<
    f64[784][100]
    [
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...],
      ...
    ]
>]

https://github.com/elixir-nx/nx/blob/85a3d3a4d3eab44c588d05a8f32c25c952ba4787/nx/lib/nx/defn/grad.ex#L479-L481
op は:tensorなので、:noneを返す

  defp grad(_op, _args, _ans, _g, _cache) do
    :none
  end

https://github.com/elixir-nx/nx/blob/85a3d3a4d3eab44c588d05a8f32c25c952ba4787/nx/lib/nx/defn/grad.ex#L748
:tensorなので []を返す

  defp jvp(op, _, _) when op in @constants do
    []
  end

[] なので {tensor(0.0),cache} を返す

defp grad_jvps([], _ans, _g, cache), do: {Expr.tensor(0.0), cache}

これだと何も実行されていないので
多分opが実行した関数がスタックされて
Nx.{ add, subtract, multiply, divide} なら jvps
それ以外なら gradで再帰的に微分を行っていると思われる

以下を見ると [{in,out},{in,out}]として見ると前の記事の乗算レイヤのbackwardと同じ形を取っているのがわかる
https://github.com/elixir-nx/nx/blob/85a3d3a4d3eab44c588d05a8f32c25c952ba4787/nx/lib/nx/defn/grad.ex#L493-L496

  defp jvp(:multiply, [x, y], _ans) do
    [{x, y}, {y, x}]
  end

これらを鑑みると、grad関数は実行されたdefn内のコードを解析してbackward部分を自動的に生成している、誤差逆伝播法ではないかと思われる。

5 初期化

最後は次の学習に影響しないようにidsリストを初期化してると思われるので、割愛

本記事は以上になりますありがとうございました

5
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
1