はじめに
久しぶりに Livebook で Axon を使ってみたら上手く動かなかったので共有します
実装したノートブックはこちら
実行したコード
動かしたのは以下の記事で実行していたコードです
OR 演算を Axon で学習しています
モジュールのインストール
2024/06/21 現在、 Axon の最新版は 0.6.1 になっています
以下のように各モジュール最新版を使うように指定します
Mix.install([
{:nx, "~> 0.7"},
{:axon, "~> 0.6"},
{:exla, "~> 0.7"},
{:kino, "~> 0.13"},
{:kino_vega_lite, "~> 0.1"}
])
インストールされるものは以下の通りです
axon 0.6.1
complex 0.5.0
elixir_make 0.8.4
exla 0.7.2
fss 0.1.1
kino 0.13.0
kino_vega_lite 0.1.11
nimble_pool 1.1.0
nx 0.7.2
polaris 0.1.0
table 0.1.2
telemetry 1.2.1
vega_lite 0.1.9
xla 0.6.0
Axon のマクロを使うため、 require Axon しておきます
require Axon
学習データの生成
元の記事と同じコードです
generate_train_data = fn ->
inputs =
1..2
|> Enum.into(%{}, fn index ->
{
"input#{index}",
1..32
|> Enum.map(fn _ -> Enum.random(0..1) end)
|> Nx.tensor()
|> Nx.new_axis(1)
}
end)
labels = Nx.logical_or(inputs["input1"], inputs["input2"])
{inputs, labels}
end
generate_train_data.()
train_data =
generate_train_data
|> Stream.repeatedly()
|> Enum.take(1000)
これで、ランダムな 0 1 の入力二つに対して、 OR 演算した結果の出力という組み合わせが 1000 個作成できました
モデル定義
ここも元のコードと同じです
input1 = Axon.input("input1", shape: {nil, 1})
input2 = Axon.input("input2", shape: {nil, 1})
model =
Axon.concatenate(input1, input2)
|> Axon.dense(8, activation: :relu)
|> Axon.dense(1, activation: :sigmoid)
トレーニング
損失と正解率のグラフ出力エリアを用意します
loss_plot =
VegaLite.new(width: 300)
|> VegaLite.mark(:line)
|> VegaLite.encode_field(:x, "step", type: :quantitative)
|> VegaLite.encode_field(:y, "loss", type: :quantitative)
|> Kino.VegaLite.new()
acc_plot =
VegaLite.new(width: 300)
|> VegaLite.mark(:line)
|> VegaLite.encode_field(:x, "step", type: :quantitative)
|> VegaLite.encode_field(:y, "accuracy", type: :quantitative)
|> Kino.VegaLite.new()
Kino.Layout.grid([loss_plot, acc_plot], columns: 2)
トレーニングを実行します
trained_state =
model
|> Axon.Loop.trainer(:binary_cross_entropy, :sgd)
|> Axon.Loop.metric(:accuracy, "accuracy")
|> Axon.Loop.kino_vega_lite_plot(loss_plot, "loss", event: :epoch_completed)
|> Axon.Loop.kino_vega_lite_plot(acc_plot, "accuracy", event: :epoch_completed)
|> Axon.Loop.run(train_data, %{}, epochs: 5, iterations: 1000, compiler: EXLA)
これを実行すると、問題が発生します
トレーニングで損失が増加する
以前の実行結果(正しいもの)は以下のように損失が減少し、正解率が上昇しています
ところが、 Axon 0.6.1 を使うと以下のように損失が上昇し、正解率が減少しています
これでは全くトレーニングできていません
対処
この現象は Axon 0.6.1 のバグです
GitHub の Issue にも登録されていて 2024/05/14 に解消済ですが、反映されたバージョンがまだ Hex にリリースされていません
そのため、 Axon は GitHub から最新版を取得する必要があります
先頭のセットアップセルを以下のように書き換えてください
Mix.install([
{:nx, "~> 0.7"},
{:axon, "~> 0.6", git: "https://github.com/elixir-nx/axon/"},
{:exla, "~> 0.7"},
{:kino, "~> 0.13"},
{:kino_vega_lite, "~> 0.1"}
])
差分
- {:axon, "~> 0.6"},
+ {:axon, "~> 0.6", git: "https://github.com/elixir-nx/axon/"},
また、トレーニングのコードも以下のように変更してください
初期状態として Axon.ModelState.empty() を渡すようになっています( %{} でも警告されるだけでエラーにはならない)
trained_state =
model
|> Axon.Loop.trainer(:binary_cross_entropy, :sgd)
|> Axon.Loop.metric(:accuracy, "accuracy")
|> Axon.Loop.kino_vega_lite_plot(loss_plot, "loss", event: :epoch_completed)
|> Axon.Loop.kino_vega_lite_plot(acc_plot, "accuracy", event: :epoch_completed)
|> Axon.Loop.run(train_data, Axon.ModelState.empty(), epochs: 5, iterations: 1000, compiler: EXLA)
差分
- |> Axon.Loop.run(train_data, %{}, epochs: 5, iterations: 1000, compiler: EXLA)
+ |> Axon.Loop.run(train_data, Axon.ModelState.empty(), epochs: 5, iterations: 1000, compiler: EXLA)
この状態で実行すると、以前と同じように正しくトレーニングできます
Explorer と併用する際の注意
以下のように Axon の GitHub 最新版と Explorer の Hex 最新版を併用しようとすると、依存関係のエラーが発生します
Mix.install([
{:exla, "~> 0.7"},
{:axon, "~> 0.6", git: "https://github.com/elixir-nx/axon/"},
{:kino, "~> 0.13"},
{:kino_vega_lite, "~> 0.1"},
{:explorer, "~> 0.8"}
])
エラー
Unchecked dependencies for environment dev:
* table_rex (Hex package)
the dependency table_rex 4.0.0
> In deps/explorer/mix.exs:
{:table_rex, "~> 3.1.1 or ~> 4.0.0", [env: :prod, hex: "table_rex", repo: "hexpm", optional: false]}
does not match the requirement specified
> In deps/axon/mix.exs:
{:table_rex, "~> 3.1.1", [env: :prod, hex: "table_rex", optional: true, repo: "hexpm"]}
Ensure they match or specify one of the above in your deps and set "override: true"
エラーメッセージにある通り、 "override: true" を指定して table_rex をインストールすればエラーが解消されます
Mix.install([
{:exla, "~> 0.7"},
{:axon, "~> 0.6", git: "https://github.com/elixir-nx/axon/"},
{:kino, "~> 0.13"},
{:kino_vega_lite, "~> 0.1"},
{:explorer, "~> 0.8"},
{:table_rex, "~> 4.0", override: true}
])
まとめ
モジュールのバージョンが上がったときは注意が必要です
GitHub の Issues などを見に行き、バグが報告されていないか確認しましょう



