9
2
お題は不問!Qiita Engineer Festa 2024で記事投稿!
Qiita Engineer Festa20242024年7月17日まで開催中!

2024/06/21 現在 Axon を使う場合の注意(GitHubから最新を取得すること)

Last updated at Posted at 2024-06-21

はじめに

久しぶりに Livebook で Axon を使ってみたら上手く動かなかったので共有します

実装したノートブックはこちら

実行したコード

動かしたのは以下の記事で実行していたコードです

OR 演算を Axon で学習しています

モジュールのインストール

2024/06/21 現在、 Axon の最新版は 0.6.1 になっています

以下のように各モジュール最新版を使うように指定します

Mix.install([
  {:nx, "~> 0.7"},
  {:axon, "~> 0.6"},
  {:exla, "~> 0.7"},
  {:kino, "~> 0.13"},
  {:kino_vega_lite, "~> 0.1"}
])

インストールされるものは以下の通りです

  axon 0.6.1
  complex 0.5.0
  elixir_make 0.8.4
  exla 0.7.2
  fss 0.1.1
  kino 0.13.0
  kino_vega_lite 0.1.11
  nimble_pool 1.1.0
  nx 0.7.2
  polaris 0.1.0
  table 0.1.2
  telemetry 1.2.1
  vega_lite 0.1.9
  xla 0.6.0

Axon のマクロを使うため、 require Axon しておきます

require Axon

学習データの生成

元の記事と同じコードです

generate_train_data = fn ->
  inputs =
    1..2
    |> Enum.into(%{}, fn index ->
      {
        "input#{index}",
        1..32
        |> Enum.map(fn _ -> Enum.random(0..1) end)
        |> Nx.tensor()
        |> Nx.new_axis(1)
      }
    end)

  labels = Nx.logical_or(inputs["input1"], inputs["input2"])

  {inputs, labels}
end
generate_train_data.()
train_data =
  generate_train_data
  |> Stream.repeatedly()
  |> Enum.take(1000)

これで、ランダムな 0 1 の入力二つに対して、 OR 演算した結果の出力という組み合わせが 1000 個作成できました

モデル定義

ここも元のコードと同じです

input1 = Axon.input("input1", shape: {nil, 1})
input2 = Axon.input("input2", shape: {nil, 1})
model =
  Axon.concatenate(input1, input2)
  |> Axon.dense(8, activation: :relu)
  |> Axon.dense(1, activation: :sigmoid)

トレーニング

損失と正解率のグラフ出力エリアを用意します

loss_plot =
  VegaLite.new(width: 300)
  |> VegaLite.mark(:line)
  |> VegaLite.encode_field(:x, "step", type: :quantitative)
  |> VegaLite.encode_field(:y, "loss", type: :quantitative)
  |> Kino.VegaLite.new()

acc_plot =
  VegaLite.new(width: 300)
  |> VegaLite.mark(:line)
  |> VegaLite.encode_field(:x, "step", type: :quantitative)
  |> VegaLite.encode_field(:y, "accuracy", type: :quantitative)
  |> Kino.VegaLite.new()

Kino.Layout.grid([loss_plot, acc_plot], columns: 2)

トレーニングを実行します

trained_state =
  model
  |> Axon.Loop.trainer(:binary_cross_entropy, :sgd)
  |> Axon.Loop.metric(:accuracy, "accuracy")
  |> Axon.Loop.kino_vega_lite_plot(loss_plot, "loss", event: :epoch_completed)
  |> Axon.Loop.kino_vega_lite_plot(acc_plot, "accuracy", event: :epoch_completed)
  |> Axon.Loop.run(train_data, %{}, epochs: 5, iterations: 1000, compiler: EXLA)

これを実行すると、問題が発生します

トレーニングで損失が増加する

以前の実行結果(正しいもの)は以下のように損失が減少し、正解率が上昇しています

  • 損失

    loss_graph.png

  • 正解率

    acc_graph.png

ところが、 Axon 0.6.1 を使うと以下のように損失が上昇し、正解率が減少しています

  • 損失

    LOSS

  • 正解率

    ACC

これでは全くトレーニングできていません

対処

この現象は Axon 0.6.1 のバグです

GitHub の Issue にも登録されていて 2024/05/14 に解消済ですが、反映されたバージョンがまだ Hex にリリースされていません

そのため、 Axon は GitHub から最新版を取得する必要があります

先頭のセットアップセルを以下のように書き換えてください

Mix.install([
  {:nx, "~> 0.7"},
  {:axon, "~> 0.6", git: "https://github.com/elixir-nx/axon/"},
  {:exla, "~> 0.7"},
  {:kino, "~> 0.13"},
  {:kino_vega_lite, "~> 0.1"}
])

差分

-  {:axon, "~> 0.6"},
+  {:axon, "~> 0.6", git: "https://github.com/elixir-nx/axon/"},

また、トレーニングのコードも以下のように変更してください

初期状態として Axon.ModelState.empty() を渡すようになっています( %{} でも警告されるだけでエラーにはならない)

trained_state =
  model
  |> Axon.Loop.trainer(:binary_cross_entropy, :sgd)
  |> Axon.Loop.metric(:accuracy, "accuracy")
  |> Axon.Loop.kino_vega_lite_plot(loss_plot, "loss", event: :epoch_completed)
  |> Axon.Loop.kino_vega_lite_plot(acc_plot, "accuracy", event: :epoch_completed)
  |> Axon.Loop.run(train_data, Axon.ModelState.empty(), epochs: 5, iterations: 1000, compiler: EXLA)

差分

-  |> Axon.Loop.run(train_data, %{}, epochs: 5, iterations: 1000, compiler: EXLA)
+  |> Axon.Loop.run(train_data, Axon.ModelState.empty(), epochs: 5, iterations: 1000, compiler: EXLA)

この状態で実行すると、以前と同じように正しくトレーニングできます

Explorer と併用する際の注意

以下のように Axon の GitHub 最新版と Explorer の Hex 最新版を併用しようとすると、依存関係のエラーが発生します

Mix.install([
  {:exla, "~> 0.7"},
  {:axon, "~> 0.6", git: "https://github.com/elixir-nx/axon/"},
  {:kino, "~> 0.13"},
  {:kino_vega_lite, "~> 0.1"},
  {:explorer, "~> 0.8"}
])

エラー

Unchecked dependencies for environment dev:
* table_rex (Hex package)
  the dependency table_rex 4.0.0

  > In deps/explorer/mix.exs:
    {:table_rex, "~> 3.1.1 or ~> 4.0.0", [env: :prod, hex: "table_rex", repo: "hexpm", optional: false]}

  does not match the requirement specified

  > In deps/axon/mix.exs:
    {:table_rex, "~> 3.1.1", [env: :prod, hex: "table_rex", optional: true, repo: "hexpm"]}

  Ensure they match or specify one of the above in your deps and set "override: true"

エラーメッセージにある通り、 "override: true" を指定して table_rex をインストールすればエラーが解消されます

Mix.install([
  {:exla, "~> 0.7"},
  {:axon, "~> 0.6", git: "https://github.com/elixir-nx/axon/"},
  {:kino, "~> 0.13"},
  {:kino_vega_lite, "~> 0.1"},
  {:explorer, "~> 0.8"},
  {:table_rex, "~> 4.0", override: true}
])

まとめ

モジュールのバージョンが上がったときは注意が必要です

GitHub の Issues などを見に行き、バグが報告されていないか確認しましょう

9
2
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
9
2