6
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

Elixir Nx で学ぶ機械学習 - Axon無し (Livebook)

Last updated at Posted at 2022-11-10

「Elixir Nx の基礎 (Livebook)」 で見たように、Nxテンソル自動微分 をサポートしています。Nx は Python の NumPy の機能に加えて、PyTorch などの機械学習ライブラリの基本部分を既に持っていることになります。今回は、Axon を使わないで、Nx だけで機械学習を行ってみたいと思います。

「PyTorch & 深層学習プログラミング3章 (日経BP)」に 単回帰分析 の例があるので、それを Nx を使って実装していきたいと思います。機械学習は、一見すると魔法のようなテクニックですが、その本質は勾配計算(微分)にあると思います。このシンプルな例で、その本質がクッキリ表れていることを望みます。

単回帰分析というのは1つの目的変数を1つの説明変数で予測するもので、その2変量の間の関係性をY=aX+bという一次方程式の形で表します。a(傾き)とb(Y切片)がわかれば、X(身長)からY(体重)を予測することができるわけです。
「単回帰分析とは」

Elixir Axon 機械学習記事のまとめ

問題 記事 活性化関数 出力層 損失関数
単回帰 Elixir Nx で学ぶ - - MSE
2値分類 排他的論理和 tanh sigmoid binary_cross_entropy
多値分類 手書き数字認識 relu softmax categorical_cross_entropy

関連記事
Elixir Nx の基礎 (Livebook) - Qiita
Livebook 事始め (Elixir) - Qiita
【機械学習】誤差逆伝播法のコンパクトな説明 - Qiita

1. データ前処理

1-1. ライブラリのインストール

Mix.install([
  {:nx, "~> 0.4"},
  {:vega_lite, "~> 0.1.6"},
  {:kino_vega_lite, "~> 0.1.4"}
])

alias VegaLite, as: Vl

VegaLite はデータを可視化するのに使います。機械学習の本質部分には使いません。またこの程度のシンプルな機械学習なら Exla (GPU) は必要ありません。

1-2. 素データ

sample = [
    [166, 58.7],
    [176.0, 75.7],
    [171.0, 62.1],
    [173.0, 70.4],
    [169.0,60.1]
]

5人分の身長と体重のデータです。X(身長)からY(体重)を予測するため、その関係を表す Y=aX+b という一次方程式を求めたいと思います。

データの散布具合を見たいと思います。

data =
  sample
  |> Enum.map(fn [height, weight] -> %{height: height, weight: weight} end)

Vl.new()
|> Vl.data_from_values(data)
|> Vl.mark(:point)
|> Vl.encode_field(:x, "height", type: :quantitative, title: "身長", scale: [zero: false])
|> Vl.encode_field(:y, "weight", type: :quantitative, title: "体重", scale: [zero: false])

image.png

1-3. データのテンソル化

sample = Nx.tensor(sample, names: [:x, :y])
x = sample[y: 0]
y = sample[y: 1]

素データをテンソル化し、身長配列を x、体重配列を y とします。x と y を表示します。

x
#Nx.Tensor<
  f32[x: 5]
  [166.0, 176.0, 171.0, 173.0, 169.0]
>


y
#Nx.Tensor<
  f32[x: 5]
  [166.0, 176.0, 171.0, 173.0, 169.0]
>

1-4. データの加工

x = Nx.subtract(x, Nx.mean(x))
y = Nx.subtract(y, Nx.mean(y))

勾配降下法 では扱う数値の絶対値は小さいほうが望ましいとされます。できれば1以下。ここでは身長や体重から、それぞれの平均値を引いて小さくします。

現在のデータの散布図をみます。

Vl.new()
|> Vl.data_from_values(height: Nx.to_flat_list(x), weight: Nx.to_flat_list(y))
|> Vl.mark(:point)
|> Vl.encode_field(:x, "height", type: :quantitative, title: "身長", scale: [zero: false])
|> Vl.encode_field(:y, "weight", type: :quantitative, title: "体重", scale: [zero: false])

image.png

2. 機械学習

defmodule LinearRegression do
  import Nx.Defn

  defn pred({w, b}, x) do
    w * x + b 
  end

  defn mse(yp, y) do
    (yp - y)
    |> Nx.power(2)
    |> Nx.mean()
  end

  defn loss(params, x, y) do
    yp = pred(params, x)
    mse(yp, y)
  end

  defn update({w, b} = params, x, y, lr) do
    {grad_w, grad_b} = grad(params, &loss(&1, x, y))
    { w - grad_w * lr,  b - grad_b * lr }
  end

  defn init_params do
    { Nx.tensor(1.0), Nx.tensor(1.0)}
  end
end

機械学習の基本的な流れは、以下のような構成要素によって定式化されています。以下にそれぞれ詳細に見ていきます。

  • 2-1. 予測関数
  • 2-2. 損失関数
  • 2-3. 損失計算
  • 2-4. 勾配計算とパラメータ更新

2-1. 予測関数 - pred

予測値を yp とすると、予測関数は以下のような1次式で表されます。傾きが w で、切片が b です。これが求めるべきパラメータとなります。

yp = w * x + b

2-2. 損失関数 - mse

損失関数は、予測値 yp と正解値 y との間の距離の平均値として定義されます。つまり 平均二乗誤差(MSE) です。

  defn mse(yp, y) do
    (yp - y)
    |> Nx.power(2)
    |> Nx.mean()
  end

2-3. 損失計算 - loss

損失計算は、予測関数と損失関数の合成関数として求められます。わかり易いように少し書き換えてみます。

  defn loss(params, x, y) do
    mse( pred(params, x), y )
  end

ここで loss 関数は変数params {w. b} の1次式としてみます。x と y は固定値です

2-4. 勾配計算とパラメータ更新 - update

勾配計算には Nx の自動微分を使います。ここがキモです。変数 params に関する loss 関数の偏微分 であり、勾配ベクトル {grad_w, grad_b} が求まります。

  defn update({w, b} = params, x, y, lr) do
    {grad_w, grad_b} = grad(params, &loss(&1, x, y))
    { w - grad_w * lr,  b - grad_b * lr }
  end

lr学習率 と呼ばれる小さな値です。勾配ベクトルに掛けて、パラメータを少しづつズラシテいくのに使われます。理論的な値といううよりは経験的な値です。

3. 訓練

訓練データ(x) と正解ラベル(y) を用いて訓練を行います。訓練はパラメータを修正しながら行う繰り返し計算となります。epochs が繰り返し数となります。lr学習率 です。

epochs = 500
lr = 0.001
init_params = LinearRegression.init_params()
for _ <- 1..epochs, reduce: init_params do
  acc -> LinearRegression.update(acc, x, y, lr)
end

ここで for文の :reduceオプション で update の結果を集約しています。最終的な acc = {w, b} の値が以下のように求まります。

{#Nx.Tensor<
   f32
   1.8206826448440552
 >,
 #Nx.Tensor<
   f32
   0.3675098121166229
 >}

4. 結果評価

4-1. 学習直線

繰り返しの中で、損失の値を計算し、ログに蓄えて、VegaLite でプロットしたいと思います。
そのために LinearRegression モジュールに loss_update 関数を加えます。defn でなく def であることに注意してください。defn はテンソルの数値計算専用でいろいろ制限があるようです。

defmodule LinearRegression do
  import Nx.Defn

  defn pred({w, b}, x) do
    w * x + b 
  end

  defn mse(yp, y) do
    (yp - y)
    |> Nx.power(2)
    |> Nx.mean()
  end

  defn loss(params, x, y) do
    yp = pred(params, x)
    mse(yp, y)
  end

  defn update({w, b} = params, x, y, lr) do
    {grad_w, grad_b} = grad(params, &loss(&1, x, y))
    { w - grad_w * lr,  b - grad_b * lr }
  end

  defn init_params do
    { Nx.tensor(1.0), Nx.tensor(1.0)}
  end

  def loss_update({lvs, w, b}, x, y, lr) do
    lv = LinearRegression.loss({w, b}, x, y)
    {w, b} = LinearRegression.update({w, b}, x, y, lr)
    {[Nx.to_number(lv)|lvs], w, b}
  end
end

繰り返し計算も少し修正します。最後の値を acc に保持し、損失のログリストを lvs に取り出します。

epochs = 500
lr = 0.001
{w, b} = LinearRegression.init_params()
acc =
  for _ <- 1..epochs, reduce: {[], w, b} do
    acc -> LinearRegression.loss_update(acc, x, y, lr)
  end

{lvs, w, b} = acc

acc の内容は以下のようになります。

{[4.67464542388916, 4.6751885414123535, 4.6757354736328125, 4.676283836364746, 4.676833629608154,
  4.67738676071167, 4.677942276000977, 4.678499221801758, 4.6790595054626465, 4.679620265960693,
  4.6801838874816895, 4.680750846862793, 4.681319236755371, 4.68189001083374, 4.6824631690979,
  4.683038234710693, 4.683617115020752, 4.684196472167969, 4.684778213500977, 4.685363292694092,
  4.68595027923584, 4.686540603637695, 4.687131881713867, 4.687726020812988, 4.688323020935059,
  4.688921928405762, 4.689522743225098, 4.690126895904541, 4.690732479095459, 4.691341400146484,
  4.691952705383301, 4.692565441131592, 4.693182468414307, 4.693800926208496, 4.694421291351318,
  4.69504451751709, 4.6956706047058105, 4.696299076080322, 4.696929931640625, 4.697564125061035,
  4.69819974899292, 4.698838233947754, 4.699479579925537, 4.700122833251953, 4.700769424438477,
  4.701418399810791, 4.7020697593688965, 4.702723026275635, 4.703380107879639, ...],
 #Nx.Tensor<
   f32
   1.8206826448440552
 >,
 #Nx.Tensor<
   f32
   0.3675098121166229
 >}

それでは lvs のグラフを描きます。lvs を reverse しているところに注意してください。

Vl.new(width: 600, height: 400, title: "学習曲線")
|> Vl.data_from_values(x: 1..500, y: Enum.reverse(lvs))
|> Vl.mark(:point)
|> Vl.encode_field(:x, "x", type: :quantitative, title: "繰り返し数", scale: [zero: false])
|> Vl.encode_field(:y, "y", type: :quantitative, title: "損失", scale: [zero: false])

損失は 13 強から始まって、繰り返し回数が進むに従ってゼロに近づいていきます。勾配降下法による損失のグラフがきれいに右下下がりになっていることは学習がうまくいっていることを示します。この曲線は 学習曲線 と呼ばれます。ちなみに 5 を切ったところから進捗が鈍ります。これはそもそも最初に与えられた sample データにバラツキがあるので、これが限界ということなのでしょう。

image.png

4-2. 散布図と相関直線

直線の傾きと切片 {w, b} が求められたので、相関直線を描くことができます。 x の値を任意に2つとって、y の予測値を求め、2点を結ぶ直線を引きます。これにもともとのデータ(加工後)の散布図を重ね合わせます。VegaLite の Layer を使えば簡単に重ね合わせられます。散布図は赤色に変更しています。

x0 = -5
x9 = 5
xl = [ x0, x9]
yl = [Nx.to_number(LinearRegression.pred({w, b}, x0)),
      Nx.to_number(LinearRegression.pred({w, b}, x9))] 

Vl.new(width: 600, height: 400, title: "散布図と相関直線(加工後)")
|> Vl.layers([
  Vl.new()
  |> Vl.data_from_values(x: xl, y: yl)
  |> Vl.mark(:line)
  |> Vl.encode_field(:x, "x", type: :quantitative, title: "身長", scale: [zero: false])
  |> Vl.encode_field(:y, "y", type: :quantitative, title: "体重", scale: [zero: false]),
  Vl.new()
  |> Vl.data_from_values(height: Nx.to_flat_list(x), weight: Nx.to_flat_list(y))
  |> Vl.mark(:point)
  |> Vl.encode_field(:x, "height", type: :quantitative)
  |> Vl.encode_field(:y, "weight", type: :quantitative)
  |> Vl.encode(:color, value: "#db646f")
])

いい具合に相関直線が引かれているのがわかります。

image.png

今回は以上です。

6
1
6

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?