Elixir Axon で排他的論理和(XOR)の機械学習を行います。以下の公式ドキュメントに従います。
Modeling XOR with a neural network
排他的論理和(XOR) とは片方が1の時だけ結果が1となる論理演算です。y = x1 XOR x2
x1 | x2 | y |
---|---|---|
0 | 0 | 0 |
1 | 0 | 1 |
0 | 1 | 1 |
1 | 1 | 0 |
Elixir Axon 機械学習記事のまとめ
問題 | 記事 | 活性化関数 | 出力層 | 損失関数 |
---|---|---|---|---|
単回帰 | Elixir Nx で学ぶ | - | - | MSE |
2値分類 | 排他的論理和 | tanh | sigmoid | binary_cross_entropy |
多値分類 | 手書き数字認識 | relu | softmax | categorical_cross_entropy |
関連記事
Elixir Nx の基礎 (Livebook) - Qiita
Livebook 事始め (Elixir) - Qiita
【機械学習】誤差逆伝播法のコンパクトな説明 - Qiita
Livebook の起動は以下のサイトを参考にを行いました。
Google Colaboratory 上で Elixir Livebook を動かす(GPUを無料で使う)
まず初めに必要なライブラリをインストールしておきます。
Mix.install([
{:axon, github: "elixir-nx/axon"},
{:nx, "~> 0.4"},
{:exla, "~> 0.4"},
{:kino_vega_lite, "~> 0.1.3"}
# {:table_rex, "~> 3.1.1"}
])
Nx.Defn.default_options(compiler: EXLA)
alias VegaLite, as: Vl
1. Model
x1_input = Axon.input("x1", shape: {nil, 1})
x2_input = Axon.input("x2", shape: {nil, 1})
model =
x1_input # {nil, 1}
|> Axon.concatenate(x2_input) # {nil, 2}
|> Axon.dense(8, activation: :tanh) # {nil, 8}
|> Axon.dense(1, activation: :sigmoid) # {nil, 1}
教師あり機械学習は、数値を予測する「回帰」というモデルと、どのグループに属するかを予測する「分類」というモデルに大別されます。その中でも分類は、2値分類 と 多値分類 に分けることができます。XOR の問題は、 0 か 1 か、なので 2値分類 になります。
input(name, opts \ [])
ネットワークに input 層 を加える。
name で input 層の名前を与えなければならず、これは複数の input 層がある場合にどの input 層に対する data かを識別するために使われる。
shape は {nil, 1} ですが、 nil は動的に決定するバッチのサイズを表現しています。
concatenate(x, y, opts)
2つの input 層 を結合します。
活性化関数
- Tanh は「-1から1の間にマッピングしなおして」出力する。
- Sigmoid は「0から1の間にマッピングしなおして」出力する。
Tanh 活性化関数の特性
シグモイド関数 の微分係数(Derivative: 導関数の出力値)の最大値は0.25と小さいため、勾配が小さくなりがちで学習に時間がかかるという問題があった。そのため学習がより高速化するように、最大値が 1.0 となる「tanh関数」がよく採用されるようになった。
さらに最近のディープニューラルネットワークでは、「ReLU」がよく使われるようになっている。というのも、シグモイド関数 や tanh関数 では 勾配消失問題 を解決できなかったが、ReLU では(図2を見ると分かるように「入力が0.0より大きいと、常に1.0が出力される」ことによって) 勾配消失問題 を解決できるようになったからである。
[活性化関数]tanh関数(Hyperbolic tangent function: 双曲線正接関数)とは?
Model の概要を以下のようにしてグラフ出力することができます。
input = Nx.template({2, 1}, :u8)
Axon.Display.as_graph(model, input)
2. 訓練データ
batch_size = 32
data =
Stream.repeatedly(fn ->
x1 = Nx.random_uniform({batch_size, 1}, 0, 2)
x2 = Nx.random_uniform({batch_size, 1}, 0, 2)
y = Nx.logical_xor(x1, x2)
{%{"x1" => x1, "x2" => x2}, y}
end)
random_uniform(tensor_or_shape, min, max, opts \ []) <--- deprecated
一様分布に従った乱数を返す。 [min, max)
- min と max が整数ならば {:s, 64} のテンソルを返す
- それ以外ならば {:f, 64} のテンソルを返す。
- ただし :type オプションで明示的に指定できる。
バッチの最初の配列だけ表示してみる。
Enum.at(data, 0)
最初の5個と最後だけ表示し、それ以外は省略します。
{%{
"x1" => #Nx.Tensor<
s64[32][1]
[
[1],
[0],
[1],
[1],
[1],
---
[0]
]
>,
"x2" => #Nx.Tensor<
s64[32][1]
[
[0],
[1],
[1],
[1],
[1],
---
[0]
]
>
},
#Nx.Tensor<
u8[32][1]
[
[1],
[1],
[0],
[0],
[0],
---
[0]
]
>}
3. 訓練
epochs = 10
params =
model
|> Axon.Loop.trainer(:binary_cross_entropy, :sgd)
|> Axon.Loop.run(data, %{}, epochs: epochs, iterations: 1000)
trainer(model, loss, optimizer, loss_scale \ :identity, opts \ [])
Model と 損失関数 とオプティマイザから training loop を作り出します。
ここでは損失関数として binary_cross_entropy が指定されています。
損失関数 交差エントロピー関数 (cross_entropy) Axon.Losses
分類モデルでは損失関数に交差エントロピー関数が使われることが多いようです。2値分類では :binary_cross_entropy、多値分類では :categorical_cross_entropy です。
交差エントロピー誤差をわかりやすく説明してみる
なぜ分類問題の目的関数には交差エントロピーが使われるのか
run(loop, data, init_state \ %{}, opts \ [])
訓練データと解答(labels)を与えてloopを走らせます。
data は Enumerable かまたは Stream で、繰り返しの都度、data batch を提供してくれるもの。
オプションは以下の通り
- :epochs - ループの最大 epoch 数。デフォルト は 1。
- :iterations - 各 epoch 毎の最大繰り返し数。
- :jit_compile? - デフォルト true。
- :debug - ループの過程をトレースするためのデバッグモード。デフォルト false。
最初に指定していますが、compiler: EXLA が無ければ、とても遅いです。
4. Modelによる予測
Axon.predict(model, params, %{
"x1" => Nx.tensor([[0]]),
"x2" => Nx.tensor([[1]])
})
x1=0, x2=1 で y=0.9679677486419678 なのでokです。
#Nx.Tensor<
f32[1][1]
EXLA.Backend<cuda:0, 0.3838539344.1197867016.113259>
[
[0.9679677486419678]
]
>
Axon.predict(model, params, %{
"x1" => Nx.tensor([[0]]),
"x2" => Nx.tensor([[0]])
})
x1=0, x2=8 で y=0.012395383790135384 なのでokです。
#Nx.Tensor<
f32[1][1]
EXLA.Backend<cuda:0, 0.3838539344.1197867016.113269>
[
[0.012395383790135384]
]
>
私の環境では GPU を使いつつ、vega_lite を使うのは不可能であるようなので、グラフによる結果の可視化は諦めました。
今回は以上です。