Axon のHello world
AxonをLiveBookを使っていろいろ試してみたのですが、うまくいかず、苦労しました。
最も簡単な、y = ax + b にフィッテングしてみました。
処理の流れ
y = 2 * x + 0.5 + 乱数 の{x,y}の組を100個つくる
↓
train_model()で、フィッティングする
↓
結果表示
LiveBookで実行したコード
Setup
Mix.install([
{:nx, "~> 0.3.0"},
{:kino_vega_lite, "~> 0.1.3"},
{:axon, "~> 0.2"},
])
本体
defmodule Basic do
require Axon
def build_model(input_shape) do
inp1 = Axon.input("x", shape: input_shape)
inp1
|> Axon.dense(1)
end
defp batch do
x = Nx.tensor(for _ <-1..100, do: [:rand.uniform()])
y = Nx.multiply(x, 2)
|> Nx.add(0.5)
|> Nx.add(Nx.tensor(for _ <-1..100, do: [:rand.uniform()])|>Nx.multiply(0.5))
{x, y}
end
defp train_model(model, data, epochs) do
model
|> Axon.Loop.trainer(:mean_squared_error, :sgd)
|> Axon.Loop.run(data, %{}, epochs: epochs, iterations: 1000)
end
def run() do
model = build_model({nil})
data100 = batch()
data = Stream.repeatedly(fn -> data100 end)
# data = Stream.repeatedly(&batch/0)
model_state = train_model(model, data, 10) |> IO.inspect(label: "model_state")
result = Axon.predict(model, model_state, %{"x" => Nx.tensor([[0, 1]])|>Nx.transpose() })
{data100, result}
end
end
結果表示
{{x,y},predict_y} = Basic.run()
VegaLite.new(width: 600, height: 600)
|> VegaLite.layers([
VegaLite.new()
|> VegaLite.data_from_values(x: x |> Nx.to_flat_list(), y: y |> Nx.to_flat_list())
|> VegaLite.mark(:point, tooltip: true)
|> VegaLite.encode_field(:x, "x", type: :quantitative)
|> VegaLite.encode_field(:y, "y", type: :quantitative),
VegaLite.new()
|> VegaLite.data_from_values(x: [0, 1], y: Nx.to_flat_list(predict_y))
|> VegaLite.mark(:line)
|> VegaLite.encode_field(:x, "x", type: :quantitative)
|> VegaLite.encode_field(:y, "y", type: :quantitative)
])
実行結果
model_stateの値
model_state: %{
"dense_0" => %{
"bias" => #Nx.Tensor<
f32[1]
[0.6768180131912231]
>,
"kernel" => #Nx.Tensor<
f32[1][1]
[
[2.0908796787261963]
]
>
}
}
切片のほうは乱数のばらつきで多少誤差がありますが、ほぼ正しい値が得られてる。
やっと入口にだとりついた。
参考
https://hexdocs.pm/axon/0.2.0/multi_input_example.html#everything-together