LoginSignup
6
2

More than 1 year has passed since last update.

Axonを使って線形回帰のパラメータ(傾きと切片)の値を求めてみよう

Last updated at Posted at 2022-09-22

Axon のHello world

AxonをLiveBookを使っていろいろ試してみたのですが、うまくいかず、苦労しました。

最も簡単な、y = ax + b にフィッテングしてみました。

処理の流れ

y = 2 * x + 0.5 + 乱数 の{x,y}の組を100個つくる

train_model()で、フィッティングする

結果表示

LiveBookで実行したコード

Setup

Mix.install([
  {:nx, "~> 0.3.0"},
  {:kino_vega_lite, "~> 0.1.3"},
  {:axon, "~> 0.2"},
])

本体

defmodule Basic do
  require Axon

  def build_model(input_shape) do
    inp1 = Axon.input("x", shape: input_shape)

    inp1
    |> Axon.dense(1)
  end

  defp batch do
    x = Nx.tensor(for _ <-1..100, do: [:rand.uniform()])
    y = Nx.multiply(x, 2)
     |> Nx.add(0.5)
     |> Nx.add(Nx.tensor(for _ <-1..100, do: [:rand.uniform()])|>Nx.multiply(0.5))
    {x, y}
  end

  defp train_model(model, data, epochs) do
    model
    |> Axon.Loop.trainer(:mean_squared_error, :sgd)
    |> Axon.Loop.run(data, %{}, epochs: epochs, iterations: 1000)
  end

  def run() do
    model = build_model({nil})
    data100 = batch()
    data = Stream.repeatedly(fn -> data100 end)
    # data = Stream.repeatedly(&batch/0)
    model_state = train_model(model, data, 10) |> IO.inspect(label: "model_state")

    result = Axon.predict(model, model_state, %{"x" => Nx.tensor([[0, 1]])|>Nx.transpose()   })
    {data100, result}
  end
end

結果表示

{{x,y},predict_y} = Basic.run()

VegaLite.new(width: 600, height: 600)
|> VegaLite.layers([
  VegaLite.new()
  |> VegaLite.data_from_values(x: x |> Nx.to_flat_list(), y: y |> Nx.to_flat_list())
  |> VegaLite.mark(:point, tooltip: true)
  |> VegaLite.encode_field(:x, "x", type: :quantitative)
  |> VegaLite.encode_field(:y, "y", type: :quantitative),
  VegaLite.new()
  |> VegaLite.data_from_values(x: [0, 1], y: Nx.to_flat_list(predict_y))
  |> VegaLite.mark(:line)
  |> VegaLite.encode_field(:x, "x", type: :quantitative)
  |> VegaLite.encode_field(:y, "y", type: :quantitative)
])

実行結果

image.png

model_stateの値

model_state: %{
  "dense_0" => %{
    "bias" => #Nx.Tensor<
      f32[1]
      [0.6768180131912231]
    >,
    "kernel" => #Nx.Tensor<
      f32[1][1]
      [
        [2.0908796787261963]
      ]
    >
  }
}

切片のほうは乱数のばらつきで多少誤差がありますが、ほぼ正しい値が得られてる。

やっと入口にだとりついた。

参考
https://hexdocs.pm/axon/0.2.0/multi_input_example.html#everything-together

6
2
5

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
2