16
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

ElixirAdvent Calendar 2022

Day 10

Livebookで試しながら作る、はじめてのAxonプログラム

Last updated at Posted at 2022-11-22

Axonでニューラルネットを学習させる様子を、一次関数 (y = ax + b) で試してみたときのLivebookです。

Axonで線形回帰

# Notebook setup: install the dependencies and configure Nx to run on EXLA.
Mix.install(
  [
    # Tensor library and its XLA-based backend/compiler.
    {:nx, "~> 0.3.0"},
    {:exla, "~> 0.3.0"},
    # VegaLite chart integration for Livebook (used by the View module below).
    {:kino_vega_lite, "~> 0.1.3"},
    # Neural-network library built on Nx.
    {:axon, "~> 0.2"}
  ],
  config: [
    nx: [
      # Make EXLA both the default tensor backend and the default defn compiler.
      default_backend: EXLA.Backend,
      default_defn_options: [compiler: EXLA]
    ]
  ],
  system_env: [
    # Select the CPU build of XLA.
    XLA_TARGET: "cpu"
  ]
)

modelについて

Axon.dense は input * kernel + bias を計算する層です。下の例では kernel = 3、bias = 2 を手で与えているので、入力 1 に対して 1 × 3 + 2 = 5 が出力されます。

# A model with one input feature and one output unit: a single dense layer
# named "L1". Axon.dense computes `input . kernel + bias`.
model =
  Axon.input("x", shape: {nil, 1})
  |> Axon.dense(1, name: "L1")

# Hand-written parameters for layer "L1", so the arithmetic can be checked by
# eye. The dense kernel has shape {input_features, units} = {1, 1} and the
# bias has shape {units} = {1}. A rank-1 kernel would make the dot product
# drop an axis and return shape {1} instead of the batch shape {1, 1}.
model_state = %{
  "L1" => %{
    "bias" => Nx.tensor([2]),
    "kernel" => Nx.tensor([[3]])
  }
}

# One sample with x = 1: expected output 1 * 3 + 2 = 5, with shape {1, 1}.
out = Axon.predict(model, model_state, %{"x" => Nx.tensor([[1]])})

正解データ作成（y = 2x + 0.5 に、0〜0.5 の一様乱数ノイズを加えたもの）

# Generates one batch of 100 noisy samples of y = 2x + 0.5.
# x is uniform in [0, 1). The noise is a second uniform column scaled by 0.5,
# i.e. in [0, 0.5). Returns {x, y}, both tensors of shape {100, 1}.
batch = fn ->
  random_column = fn -> Nx.tensor(for _ <- 1..100, do: [:rand.uniform()]) end

  x = random_column.()
  noise = Nx.multiply(random_column.(), 0.5)

  y =
    x
    |> Nx.multiply(2)
    |> Nx.add(0.5)
    |> Nx.add(noise)

  {x, y}
end

x, y は、それぞれ {100, 1} の行列（テンソル）です。

この例のモデルは入力1次元・出力1次元なので、100 サンプル × 1 次元で {100, 1} の行列になります。

# Draw one sample batch to visualise; x and y are both {100, 1} tensors.
{x, y} = batch.()
defmodule View do
  @moduledoc """
  Plotting helpers for this notebook, built on VegaLite.
  """

  @doc """
  Builds a 600x600 chart with two layers: a scatter plot of `x1`/`y1`
  (points with tooltips) and a line through `x2`/`y2`.

  All arguments are plain lists of numbers. `x2` and `y2` default to `[]`,
  which leaves the line layer empty so only the points are drawn.
  """
  def plotxy(x1, y1, x2 \\ [], y2 \\ []) do
    scatter = layer(x1, y1, :point, tooltip: true)
    fitted_line = layer(x2, y2, :line, [])

    VegaLite.new(width: 600, height: 600)
    |> VegaLite.layers([scatter, fitted_line])

    # |> VegaLite.Export.save!("result.vl.json")
    # |> VegaLite.Viewer.show()
  end

  # One layer that maps the "x"/"y" data fields onto quantitative axes.
  defp layer(xs, ys, mark, mark_opts) do
    VegaLite.new()
    |> VegaLite.data_from_values(x: xs, y: ys)
    |> VegaLite.mark(mark, mark_opts)
    |> VegaLite.encode_field(:x, "x", type: :quantitative)
    |> VegaLite.encode_field(:y, "y", type: :quantitative)
  end
end
# Scatter-plot the sample batch on its own (no fitted line yet).
x
|> Nx.to_flat_list()
|> View.plotxy(Nx.to_flat_list(y))
# Trains `model` with mean-squared-error loss and plain SGD.
# `data` is an enumerable of {x, y} batches; each epoch runs 1000 iterations.
# The result is used as the trained model state for Axon.predict.
train_model = fn model, data, epochs ->
  loop = Axon.Loop.trainer(model, :mean_squared_error, :sgd)
  Axon.Loop.run(loop, data, %{}, epochs: epochs, iterations: 1000)
end

Axonが提供している損失関数 (loss function)（Axon.Losses モジュール。trainer には :mean_squared_error のようにアトムで指定できる）

  • binary_cross_entropy/3
  • categorical_cross_entropy/3
  • categorical_hinge/3
  • connectionist_temporal_classification/3
  • cosine_similarity/3
  • hinge/3
  • kl_divergence/3
  • log_cosh/3
  • margin_ranking/3
  • mean_absolute_error/3
  • mean_squared_error/3
  • poisson/3
  • soft_margin/3

Axonが提供しているオプティマイザー（Axon.Optimizers モジュール。trainer には :sgd のようにアトムで指定できる）

  • adabelief/2
  • adagrad/2
  • adam/2
  • adamw/2
  • lamb/2
  • noisy_sgd/2
  • radam/2
  • rmsprop/2
  • sgd/2
  • yogi/2
# Fix one batch of 100 samples to train on.
input_data_100 = batch.()
# Feed the same batch on every iteration:
data = Stream.repeatedly(fn -> input_data_100 end)
# To generate a fresh batch on every iteration instead, pass the `batch`
# closure itself. It is an anonymous function bound to a variable, so
# `&batch/0` would not refer to it:
# data = Stream.repeatedly(batch)

# Train for 10 epochs (1000 iterations each) using the train_model helper,
# which wraps the same MSE + SGD trainer loop.
model_state = train_model.(model, data, 10)
# Evaluate the trained model on an even grid 0.00, 0.01, ..., 1.00
# (101 rows of one feature).
x2_tensor = Nx.tensor(for i <- 0..100, do: [i / 100])

y2_tensor = Axon.predict(model, model_state, %{"x" => x2_tensor})

# Overlay the fitted line on the batch the model was actually trained on,
# rather than on the separate sample drawn earlier for the first plot.
{x_train, y_train} = input_data_100

View.plotxy(
  Nx.to_flat_list(x_train),
  Nx.to_flat_list(y_train),
  Nx.to_flat_list(x2_tensor),
  Nx.to_flat_list(y2_tensor)
)
defmodule Basic do
  @moduledoc """
  Self-contained version of the notebook: builds a one-unit dense model,
  fits it to noisy samples of y = 2x + 0.5 and plots the result.
  """
  require Axon

  @doc """
  Builds a single dense layer with one output unit on an input named "x"
  of the given shape, e.g. `{nil, 1}` for batches of one-feature rows.
  """
  def build_model(input_shape) do
    inp1 = Axon.input("x", shape: input_shape)

    inp1
    |> Axon.dense(1)
  end

  # 100 samples: x uniform in [0, 1), y = 2x + 0.5 plus uniform noise in
  # [0, 0.5). Both tensors have shape {100, 1}.
  defp batch do
    x = Nx.tensor(for _ <- 1..100, do: [:rand.uniform()])

    y =
      Nx.multiply(x, 2)
      |> Nx.add(0.5)
      |> Nx.add(Nx.tensor(for _ <- 1..100, do: [:rand.uniform()]) |> Nx.multiply(0.5))

    {x, y}
  end

  # Trains with MSE loss and SGD for `epochs` epochs of 1000 iterations each;
  # the result is used as the model state for Axon.predict.
  defp train_model(model, data, epochs) do
    model
    |> Axon.Loop.trainer(:mean_squared_error, :sgd)
    |> Axon.Loop.run(data, %{}, epochs: epochs, iterations: 1000)
  end

  @doc """
  Trains for 10 epochs on one fixed batch, prints the learned state, and
  returns a chart of the training points with the fitted line over [0, 1].
  """
  def run() do
    # batch/0 yields {100, 1} tensors, so declare a one-feature input
    # (the same shape the top-level model uses).
    model = build_model({nil, 1})
    input_data_100 = batch()
    data = Stream.repeatedly(fn -> input_data_100 end)
    # data = Stream.repeatedly(&batch/0)
    model_state = train_model(model, data, 10) |> IO.inspect(label: "model_state")

    {x1_tensor, y1_tensor} = input_data_100

    # Evaluation grid 0.00, 0.01, ..., 1.00.
    x2_tensor = Nx.tensor(for i <- 0..100, do: [i / 100])

    y2_tensor = Axon.predict(model, model_state, %{"x" => x2_tensor})

    View.plotxy(
      Nx.to_flat_list(x1_tensor),
      Nx.to_flat_list(y1_tensor),
      Nx.to_flat_list(x2_tensor),
      Nx.to_flat_list(y2_tensor)
    )
  end
end
# Run the self-contained pipeline; the VegaLite chart it returns is the cell's result.
Basic.run()
16
4
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
16
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?