はじめに
本記事はElixirで機械学習/ディープラーニングができるようになるnumpy likeなライブラリ Nxを使って
「ゼロから作るDeep Learning ―Pythonで学ぶディープラーニングの理論と実装」
をElixirで書いていこうという記事になります。
今回は4,5章で使用した Nxのgrad関数が何を行っているのかを読んでみた時のメモ書き的なものになります
準備編
exla setup
1章 pythonの基本 -> とばします
2章 パーセプトロン -> とばします
3章 ニューラルネットワーク
with exla
4章 ニューラルネットワークの学習
5章 誤差逆伝播法
Nx.Defn.Kernel.grad
6章 学習に関するテクニック -> とばします
7章 畳み込みニューラルネットワーク
わかったこと
読んでみたところ、
forwardを実行
forwardの結果を保持(cache)
grad関数で実行されるdefn内の処理を解析しbackward部分を自動生成(予想)
backwrad部分は各関数ごとにgrad.exやexpr.ex内に記載されている
ということがわかりました
grad関数というかnxやばいです
Nx.Defn.Karnel.grad
https://github.com/elixir-nx/nx/blob/main/nx/lib/nx/defn/kernel.ex
https://github.com/elixir-nx/nx/blob/a9f7ea8fc09483e5b65783281dd2b459bcda11ae/nx/lib/nx/defn/kernel.ex#L231-L234
第2引数が関数かチェックしてtransformを呼んでいます
def grad(var_or_vars, fun) when is_function(fun, 1) do
    {_value, grad} = Nx.Defn.Grad.transform(var_or_vars, fun)
    grad
end
Nx.Defn.Grad.transform
https://github.com/elixir-nx/nx/blob/main/nx/lib/nx/defn/grad.ex
  def transform(to_grad, fun) do
    # 1
    {to_grad, ids} =
      Tree.composite(to_grad, %{}, fn to_grad, ids ->
        validate_grad!(to_grad)
        to_grad = Expr.metadata(to_grad, %{__MODULE__ => :to_grad})
        {to_grad, Map.put(ids, to_grad.data.id, :to_grad)}
      end)
    # 2
    expr = to_grad |> fun.() |> validate_expr!()
    # Collect all IDs in the function environment and mark
    # them as stop grads. This is an optimization to avoid
    # traversing trees when not necessary.
    # 3
    {:env, env} = Function.info(fun, :env)
    ids = stop_grads(env, ids)
    # 4
    # Grad all the parameters at the same time to share subtrees.
    {graded, _} = to_grad(expr, Expr.tensor(1.0), {ids, %{}})
    # 5
    # Now traverse the expression again zerofying
    # the parts that comes from other variables.
    # We do so by encoding special nodes in the Expr
    # AST and unpack them as we verify.
    graded =
      Tree.composite(to_grad, fn to_grad ->
        id = to_grad.data.id
        {graded, _, _} = zerofy_ids(graded, %{}, Map.delete(ids, id))
        if graded.shape == to_grad.shape do
          graded
        else
          Nx.broadcast(graded, to_grad)
        end
      end)
    {expr, graded}
  end
1 to_gradの中のTensorにユニークなIDを付けて、Tensor自体とidのリストを返す
https://github.com/elixir-nx/nx/blob/a9f7ea8fc09483e5b65783281dd2b459bcda11ae/nx/lib/nx/defn/tree.ex#L85-L88
引数がTensorかを確認して第3引数の関数を実行
  def composite(tuple, acc, fun) when is_tuple(tuple) and is_function(fun, 2) do
    {list, acc} = Enum.map_reduce(Tuple.to_list(tuple), acc, &composite(&1, &2, fun))
    {List.to_tuple(list), acc}
  end
  def composite(%T{} = expr, acc, fun) when is_function(fun, 2) do
    fun.(expr, acc)
  end
https://github.com/elixir-nx/nx/blob/a9f7ea8fc09483e5b65783281dd2b459bcda11ae/nx/lib/nx/defn/expr.ex#L83-L86
https://github.com/elixir-nx/nx/blob/a9f7ea8fc09483e5b65783281dd2b459bcda11ae/nx/lib/nx/defn/expr.ex#L701-L706
https://github.com/elixir-nx/nx/blob/a9f7ea8fc09483e5b65783281dd2b459bcda11ae/nx/lib/nx/defn/expr.ex#L190
https://hexdocs.pm/elixir/Kernel.html#make_ref/0
メタデータを作成
  #L83-L86
  def metadata(expr, metadata) when is_map(metadata) do
    expr = to_expr(expr)
    expr(expr, expr.data.context, :metadata, [expr, metadata])
  end
  #L701-L706
  defp expr(tensor, context, op, args) do
    %{tensor | data: %Expr{id: id(), op: op, args: args, context: context}}
  end
  defp to_expr(%T{data: %Expr{}} = t),
    do: t
  # l190
  def id(), do: make_ref()
2 functionの実行結果をメタデータ付きtensorかをチェックする
https://github.com/elixir-nx/nx/blob/a9f7ea8fc09483e5b65783281dd2b459bcda11ae/nx/lib/nx/defn/grad.ex#L53-L55
functionの実行結果をメタデータ付きtensorかをチェックする
  expr = to_grad |> fun.() |> validate_expr!()
  defp validate_expr!(%T{data: %Expr{}, shape: {}} = expr) do
    expr
  end
3 imageとlabelを更新されないように固定する
https://hexdocs.pm/elixir/Function.html#info/2
https://github.com/elixir-nx/nx/blob/a9f7ea8fc09483e5b65783281dd2b459bcda11ae/nx/lib/nx/defn/grad.ex#L67-L83
Function.info(func,:env)を実行すると以下のようなデータが出力される
x_test => mnist test data 10000x784
t_test => mnist test label 10000x10
iex(17)> Function.info(&Ch5.TwoLayerNet.loss_g(&1,x_test,t_test,100),:env)    
{:env,
 [
   {[_@4: #Nx.Tensor<
        s64[10000][10]
        [
          [0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
          [0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
          [0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
          [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
          [0, 0, 0, 0, 1, ...],
          ...
        ]
>, _@5: #Nx.Tensor<
        f32[10000][784]
        [
          [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...],
          ...
        ]
>], :none, :none,
    [
      {:clause, 17, [{:var, 0, :_@7}], [],
       [
         {:call, 17,
          {:remote, 17, {:atom, 0, Ch5.TwoLayerNet}, {:atom, 17, :loss_g}},
          [
            {:var, 0, :_@7},
            {:var, 17, :_@5},
            {:var, 17, :_@4},
            {:integer, 0, 100}
          ]}
       ]}
    ]}
 ]}
これに対してstop_gradsを行うと最終的に(Tensor,ids)か(_,ids) になるので
Tensorのメタデータにあるidをキーにidsにstopフラグを立てる
  {:env, env} = Function.info(fun, :env)
  ids = stop_grads(env, ids)
  # L67-L83
  defp stop_grads(list, ids) when is_list(list),
    do: Enum.reduce(list, ids, &stop_grads/2)
  defp stop_grads(tuple, ids) when is_tuple(tuple),
    do: tuple |> Tuple.to_list() |> Enum.reduce(ids, &stop_grads/2)
  defp stop_grads(%T{data: %Expr{id: id}}, ids),
    do: Map.put(ids, id, :stop)
  defp stop_grads(%_{}, ids),
    do: ids
  defp stop_grads(map, ids) when is_map(map),
    do: map |> Map.values() |> Enum.reduce(ids, &stop_grads/2)
  defp stop_grads(_, ids),
    do: ids
4 再帰的なGrad
https://github.com/elixir-nx/nx/blob/a9f7ea8fc09483e5b65783281dd2b459bcda11ae/nx/lib/nx/defn/grad.ex#L160
Expr.tensor(1.0) はbackwardのdout = 1 ?
expr = {w1,b1,w2,b2}のidはMapのkeyにないので、最後のgrad(op, args, ans, res cache)を実行
  {graded, _} = to_grad(expr, Expr.tensor(1.0), {ids, %{}})
  ## Recursion
  # The gradient recursion.
  #
  # We keep two caches. One is the result cache, which is used for
  # when visiting the same nodes in the AST.
  #
  # The other cache is the JVP cache, that shares parts of the JVP
  # computation. Both are important to reduce the amount of nodes
  # in the AST.
  defp to_grad(expr, res, cache) do
    Tree.composite(expr, cache, fn
      %T{data: %Expr{id: id, op: op, args: args}} = ans, {result_cache, jvp_cache} = cache ->
        key = [id | res.data.id]
        case result_cache do
          %{^id => :stop} ->
            {Expr.tensor(0.0), cache}
          %{^id => :to_grad} ->
            {Expr.metadata(res, %{__MODULE__ => {:tainted, id}}), cache}
          %{^key => res} ->
            {res, cache}
          %{} ->
            case grad(op, args, ans, res, cache) do
              {res, {result_cache, jvp_cache}} ->
                {res, {Map.put(result_cache, key, res), jvp_cache}}
              :none ->
                jvps =
                  case jvp_cache do
                    %{^id => jvps} -> jvps
                    %{} -> jvp(op, args, ans)
                  end
                {res, {result_cache, jvp_cache}} = grad_jvps(jvps, ans, res, cache)
                {res, {Map.put(result_cache, key, res), Map.put(jvp_cache, id, jvps)}}
            end
        end
    end)
  end
iex(40)> %T{data: %Expr{id: id, op: op, args: args}} = Expr.tensor(w1)
# Nx.Tensor<
  f64[784][100]
  
  Nx.Defn.Expr
  tensor a  f64[784][100]
>
iex(41)> id
# Reference<0.3878774584.611057665.71016>
iex(42)> op
:tensor
iex(43)> args
[#Nx.Tensor<
    f64[784][100]
    [
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...],
      ...
    ]
>]
https://github.com/elixir-nx/nx/blob/85a3d3a4d3eab44c588d05a8f32c25c952ba4787/nx/lib/nx/defn/grad.ex#L479-L481
op は:tensorなので、:noneを返す
  defp grad(_op, _args, _ans, _g, _cache) do
    :none
  end
https://github.com/elixir-nx/nx/blob/85a3d3a4d3eab44c588d05a8f32c25c952ba4787/nx/lib/nx/defn/grad.ex#L748
:tensorなので []を返す
  defp jvp(op, _, _) when op in @constants do
    []
  end
[] なので {tensor(0.0),cache} を返す
defp grad_jvps([], _ans, _g, cache), do: {Expr.tensor(0.0), cache}
これだと何も実行されていないので
多分opが実行した関数がスタックされて
Nx.{ add, subtract, multiply, divide} なら jvps
それ以外なら gradで再帰的に微分を行っていると思われる
以下を見ると [{in,out},{in,out}]として見ると前の記事の乗算レイヤのbackwardと同じ形を取っているのがわかる
https://github.com/elixir-nx/nx/blob/85a3d3a4d3eab44c588d05a8f32c25c952ba4787/nx/lib/nx/defn/grad.ex#L493-L496
  defp jvp(:multiply, [x, y], _ans) do
    [{x, y}, {y, x}]
  end
これらを鑑みると、grad関数は実行されたdefn内のコードを解析してbackward部分を自動的に生成している、誤差逆伝播法ではないかと思われる。
5 初期化
最後は次の学習に影響しないようにidsリストを初期化してると思われるので、割愛
本記事は以上になりますありがとうございました

