はじめに
今Elixirでは以下のような機械学習関連のライブラリの開発が活発で、先月Nxが0.1になりました
- 行列演算ライブラリ Nx
- 行列演算高速化ライブラリ Exla
- 機械学習用データセット取得ライブラリ SciData
- Pandasライクなデータフローライブラリ Explorer
- KerasライクなDLフレームワーク Axon
- ONNXデータを読み込む axon_onnx (開発中)
- Elixir製jupyter notebook LiveBook
今回はNx,Exla,SciData,Axonを使用してMNISTの手書き文字認識を行うアプリケーションを作成していきます
MNISTの手書き文字認識はディープラーニングのチュートリアルでよく使用される課題です
今回はアプリケーションとして作成するのでCanvasで描画した数字をLiveView経由でAxonに渡して文字を識別するようにします
DeepLeaningアプリケーション作成の流れ
だいたい以下のような流れになります
- 解決するタスクを決める -> 数字の手書き文字認識
- 使用するネットワークを構築 -> Axonのサンプルから 全結合 |> dropout |> 全結合
- 学習データを集める -> SciDataからMNISTデータを取得
- 学習を行う -> Axon.Loop.trainer
- 学習結果とネットワークを書き出す -> erlang DETSに書き出す
- アプリケーションで予測を行う -> erlang DETSに保存したネットワークと学習結果を使用
プロジェクト作成
作成していきます
mix phx.new live_axon --no-ecto
cd live_axon
defmodule LiveAxon.MixProject do
use Mix.Project
...
defp deps do
[
...
{:plug_cowboy, "~> 2.5"},
{:axon, "~> 0.1.0-dev", github: "elixir-nx/axon", branch: "main"},
{:exla, "~> 0.1.0-dev", github: "elixir-nx/nx", sparse: "exla"},
{:nx, "~> 0.1.0-dev", github: "elixir-nx/nx", sparse: "nx", override: true},
{:scidata, "~> 0.1.3"}
]
end
end
学習コードの実装
以下を参考にしていきます
https://github.com/elixir-nx/axon/blob/main/examples/vision/mnist.exs
タスクの概要やデータの変換等機械学習に関しての詳細はこちらが大変参考になります
本記事は軽くコメントをつけるまでに留めます
defmodule Mnist do
# exlaで使うデバイスを指定 rcomはamd系 gpuがなかったら最後のhost(cpu)になる
EXLA.set_preferred_defn_options([:tpu, :cuda, :rocm, :host])
# 必須
require Axon
# mnistの画像データをnxで使えるように変換
defp transform_images({bin, type, shape}) do
bin
|> Nx.from_binary(type)
|> Nx.reshape({elem(shape, 0), 784})
|> Nx.divide(255.0)
|> Nx.to_batched_list(32)
# Test split
|> Enum.split(1750)
end
# mnistのラベルデータをnxで使えるように変換
defp transform_labels({bin, type, _}) do
bin
|> Nx.from_binary(type)
|> Nx.new_axis(-1)
|> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
|> Nx.to_batched_list(32)
# Test split
|> Enum.split(1750)
end
# ニューラルネットワークの構築
# input_shape 28x28
# dense1 hidden_layer 128
# dense2 output 10
defp build_model(input_shape) do
Axon.input(input_shape)
|> Axon.dense(128, activation: :relu)
|> Axon.dropout()
|> Axon.dense(10, activation: :softmax)
end
defp train_model(model, train_images, train_labels, epochs) do
model
# 損失関数と活性化関数を指定
|> Axon.Loop.trainer(:categorical_cross_entropy, Axon.Optimizers.adamw(0.005))
# 1epochごとに計算する指標を正解率に指定
|> Axon.Loop.metric(:accuracy, "Accuracy")
|> Axon.Loop.run(Stream.zip(train_images, train_labels), epochs: epochs, compiler: EXLA)
end
defp test_model(model, model_state, test_images, test_labels) do
model
|> Axon.Loop.evaluator(model_state)
|> Axon.Loop.metric(:accuracy, "Accuracy")
|> Axon.Loop.run(Stream.zip(test_images, test_labels), compiler: EXLA)
end
def run do
# scidataから画像とラベルのバイナリデータをダウンロード
{images, labels} = Scidata.MNIST.download()
# Nx用に変換して分割
{train_images, test_images} = transform_images(images)
{train_labels, test_labels} = transform_labels(labels)
# モデルを構築してネットワーク構成を表示
model = build_model({nil, 784}) |> IO.inspect()
IO.write("\n\n Training Model \n\n")
# epoch5で学習を実行し最終的なパラメーターを受け取る
model_state =
model
|> train_model(train_images, train_labels, 5)
IO.write("\n\n Testing Model \n\n")
# テストデータで検証
model
|> test_model(model_state, test_images, test_labels)
|> IO.inspect()
IO.write("\n\n")
{model, model_state}
end
# 上記を実行して detsにモデルと学習済み重みを保存
def generate_trained_network do
{model, weight} = Mnist.run()
:dets.open_file('weight', type: :bag, file: 'weight.dets')
:dets.insert('weight',{1,{model,weight}})
:dets.sync('weight')
:dets.stop
end
end
Mnist.generate_trained_network()
detsはets(Erlang Term Storage)というインメモリデータベースがあり、それをファイルに保存して永続化できる方をdetsといいます
Elixirのオブジェクトをread/writeできて永続化もできるのでAxonのモデルと学習済み重みを保存するのに大変都合がいいです。
学習コードのタスク化
タスクとして登録しましょう
defmodule LiveAxon.MixProject do
use Mix.Project
...
defp aliases do
[
setup: ["deps.get"],
train: ["run priv/train/mnist.exs"], # 追加
"assets.deploy": ["esbuild default --minify", "phx.digest"]
]
end
end
mix train
# 以下ログ出力
--------------------------------------------------------------
Model
==============================================================
Layer Shape Parameters
==============================================================
input_0 ( input ) {nil, 784} 0
dense_0 ( dense[ "input_0" ] ) {nil, 128} 100480
relu_0 ( relu[ "dense_0" ] ) {nil, 128} 0
dropout_0 ( dropout[ "relu_0" ] ) {nil, 128} 0
dense_1 ( dense[ "dropout_0" ] ) {nil, 10} 1290
softmax_0 ( softmax[ "dense_1" ] ) {nil, 10} 0
--------------------------------------------------------------
Training Model
[warn] Failed to get CPU frequency: 0 Hz
[info] XLA service 0x111e05fa0 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
[info] StreamExecutor device (0): Host, Default Version
Epoch: 0, Batch: 1750, Accuracy: 0.9750814 loss: 0.4228649
Epoch: 1, Batch: 1750, Accuracy: 0.9819489 loss: 0.3669927
Epoch: 2, Batch: 1750, Accuracy: 0.9837040 loss: 0.3398788
Epoch: 3, Batch: 1750, Accuracy: 0.9843555 loss: 0.3236680
Epoch: 4, Batch: 1750, Accuracy: 0.9850162 loss: 0.3118583
Testing Model
Batch: 125, Accuracy: 0.9946750%{
0 => %{
"Accuracy" => #Nx.Tensor<
f32
0.9946749806404114
>
}
}
Apple M1 Max で大体3分くらいで完了しました
学習が完了するとプロジェクトのトップにweight.dets
というファイルが作成されているのでこちらを使用して文字識別を行っていきます
LiveView + Canvasで手書き文字画像生成
ページ作成
最初にページのみ作成します
defmodule LiveAxonWeb.PageLive do
use LiveAxonWeb, :live_view
@impl true
def mount(_,_,socket) do
{:ok, socket}
end
end
<div class="row">
<div class="column column-25">
<h1>Actions</h1>
<dl>
<dt><button class="button">Predict</button></dt>
<dt><button class="button">Clear</button></dt>
</dl>
</div>
<div class="column">
canvas
</div>
</div>
defmodule LiveAxonWeb.Router do
use LiveAxonWeb, :router
scope "/", LiveAxonWeb do
pipe_through :browser
get "/", PageController, :index # 消す
live "/", PageLive
end
end
Canvas数字描画実装
いわゆるお絵かきツールですね
こちらを参考に機能を削って実装します
let Hooks = {};
Hooks.Mnist = {
mounted() {
let canvas = document.getElementById("canvas");
let ctx = canvas.getContext("2d");
let clickFlg = 0;
canvas.width = 300
canvas.height = 300
const draw = (x, y) => {
ctx.lineWidth = 10;
ctx.strokeStyle = 'rgba(255, 255, 255, 1)';
if (clickFlg == "1") {
clickFlg = "2";
ctx.beginPath();
ctx.lineCap = "round";
ctx.moveTo(x, y);
} else {
ctx.lineTo(x, y);
}
ctx.stroke();
};
const setBgColor = () => {
ctx.fillRect(0,0,300,300);
}
setBgColor();
canvas.addEventListener("mousedown",() => {
clickFlg = 1;
})
canvas.addEventListener("mouseup", () => {
clickFlg = 0;
})
canvas.addEventListener("mousemove",(e) => {
if(!clickFlg) return false;
draw(e.offsetX, e.offsetY);
});
}
}
export default Hooks;
今回はhooksは1つなのでなくてもいいのですが、複数のhooksを作成する場合に連想配列を作成してそこの1つの値としてhookを1つ追加します
mounted関数内に記述すると、LiveViewのページでmount関数実行後に実行されます
hooksを作成したらLiveSocketのパラメーターにhooksで追加します
import Hooks from "./hooks";
let liveSocket = new LiveSocket("/live", Socket, {hooks: Hooks, params: {_csrf_token: csrfToken}})
LiveViewページでhooksの読み込みとcanvas要素を追加します
phx-hookを指定して読み込む要素はidが必要なので注意してください
またcanvasの変更を検知してLiveViewが再描画を行わないようにphx-update="ignore"
を指定します
<div id="mnist" phx-hook="Mnist" class="row">
<div class="column column-25">
...
</div>
<div class="column">
<canvas id="canvas" phx-update="ignore"></canvas>
</div>
</div>
描画初期化ボタン
描画を初期化する機能をLiveViewとhooksで実装します
push_eventでJS hooksの関数を実行します
defmodule LiveAxonWeb.PageLive do
use LiveAxonWeb, :live_view
...
@impl true
def handle_event("clear", _params, socket) do
{:noreply, push_event(socket, "clear", %{})}
end
end
let Hooks = {};
Hooks.Mnist = {
mounted() {
...
this.handleEvent("clear", () => {
ctx.clearRect(0, 0, canvas.width, canvas.height);
setBgColor();
});
}
}
export default Hooks;
画像生成
Predictボタンにクリックイベントとリサイズ用のcanvasを追加します
<div id="mnist" phx-hook="Mnist" class="row">
<div class="column column-25">
<h1>Actions</h1>
<dl>
<dt><button class="button" phx-click="predict">Predict</button></dt>
<dt><button class="button" phx-click="clear">Clear</button></dt>
</dl>
</div>
<div class="column">
<canvas id="canvas" phx-update="ignore"></canvas>
<canvas id="canvas2" style="visibility:hidden"></canvas>
</div>
</div>
defmodule LiveAxonWeb.PageLive do
use LiveAxonWeb, :live_view
...
@impl true
def handle_event("predict", _params, socket) do
{:noreply, push_event(socket, "predict", %{})}
end
end
リサイズ出力用canvasを作成し
MNISTと同じサイズの28x28にリサイズしてImageDataをpushEventでLiveView側に送信します。
let Hooks = {};
Hooks.Mnist = {
mounted() {
...
let canvas2 = document.getElementById("canvas2");
let ctx2 = canvas2.getContext("2d");
canvas2.width = 28;
canvas2.height = 28;
this.handleEvent("predict", () => {
ctx2.drawImage(canvas, 0, 0, 300, 300, 0, 0, 28, 28);
let data = ctx2.getImageData(0, 0, 28, 28);
ctx2.fillRect(0, 0, 28, 28);
this.pushEvent("predict_axon", data);
});
}
}
export default Hooks;
推論
predict_axonを実装します
データの加工
JS側から送信されたImageDataは以下のようなrgbaのUint8ClampedArrayを返すのでデータを加工します
[
{"864714", 61},
{"858981", 25},
{"753417", 99},
{"985846", 59},
{"993186", 117},
{"486748", 234},
{"149540", 231},
{"948539", 255},
...
]
defmodule LiveAxonWeb.PageLive do
use LiveAxonWeb, :live_view
@impl true
def handle_event("predict_axon", %{"data" => data}, socket) do
ans =
data
|> convert_image_data_to_tensor()
|> convert_mnist_predict_data()
{:noreply, socket}
end
def convert_image_data_to_tensor(data) do
data
|> Map.to_list()
|> Enum.map(fn {k, v} -> {String.to_integer(k), v} end)
|> Enum.sort()
|> Enum.map(fn {_k, v} -> v end)
|> Nx.tensor()
end
def convert_mnist_predict_data(pixel) do
{row} = Nx.shape(pixel)
pixel
|> Nx.reshape({div(row, 4), 4})
|> Nx.slice_along_axis(0, 3, axis: 1)
|> Nx.mean(axes: [-1])
|> Nx.round()
|> Nx.reshape({1, 784})
|> tap(& &1 |> Nx.reshape({28,28})|> Nx.to_heatmap() |> IO.inspect())
|> Nx.divide(255.0)
end
end
convert_image_data_to_tensorは以下のような処理を行っています
Map
|> {key(string), val(integer)}のタプルリスト
|> {key(integer), val(integer)}のタプルリスト
|> タプルのkeyでsort
|> タプルのvalのみ抽出
|> Nx.tensor化
onvert_mnist_predict_dataは以下のような処理を行っています
Nx.shape(pixel) #要素数を取得
pixel
|> [{r,g,b,a},....]の形式に変換
|> [{r,g,b},....]の形式に変換
|> [avg,....]の形式に変換
|> 平均の小数点以下を丸める
|> {784}をAxon用に{1,784}に変換
|> 28x28にしてheatmapにして表示
|> 255で割って0~1.0の間に変換
Axonで推論
学習コードで作成したモデルと重みデータをdetsから読み込んでAxonで推論を実行します
答え表示用にansをassignし、clear eventでも初期化するようにします
predictの結果はNx.Tensorなので|> Nx.argmax() |> Nx.to_number()
で一番確率が高いindexを取得して数値に変換しています
defmodule LiveAxonWeb.PageLive do
use LiveAxonWeb, :live_view
require Axon # requireしないと使えない
@impl true
def mount(_,_,socket) do
{:ok, assign(socket, :ans, nil)}
end
@impl true
def handle_event("clear", _params, socket) do
{
:noreply,
socket
|> assign(:ans, nil)
|> push_event("clear", %{})
}
end
...
@impl true
def handle_event("predict_axon", %{"data" => data}, socket) do
ans =
data
|> convert_image_data_to_tensor()
|> convert_mnist_predict_data()
|> predict()
{:noreply, assign(socket, :ans, ans)}
end
...
def predict(pixel) do
{:ok, params} = :dets.open_file('weight.dets')
[{1,{model,weight}}] = :dets.lookup(params,1)
Axon.predict(model,weight, pixel) |> Nx.argmax() |> Nx.to_number()
end
end
推論結果が出たらansを表示するようにします
<div id="mnist" phx-hook="Mnist" class="row">
<div class="column column-25">
<h1>Actions</h1>
<dl>
<dt><button class="button" phx-click="predict">Predict</button></dt>
<dt><button class="button" phx-click="clear">Clear</button></dt>
</dl>
<%= if @ans do %>
<h1>Answer:<%= @ans %></h1>
<% end %>
</div>
<div class="column">
<canvas id="canvas" phx-update="ignore"></canvas>
<canvas id="canvas2" style="visibility:hidden"></canvas>
</div>
</div>
デモ
最後に
Axon,Nx,SciData,LiveView,Canvas,DETSを使ってMNIST手書き文字認識のアプリケーションを作成できました!
学習結果とモデルをDETSに保存させればResNetで画像分類などを簡単にweb applicationに組み込むことができますし
ONNXも読み込めれば他の学習済みネットワークを使用できたりして夢が広がります!
本記事は以上になりますありがとうございました
code
参考ページ
- https://qiita.com/piacerex/items/9cf4f328222103458167
- https://www.otwo.jp/blog/canvas-drawing/
- https://developer.mozilla.org/en-US/docs/Web/API/CanvasRenderingContext2D
- https://developer.mozilla.org/en-US/docs/Web/API/Canvas_API
- https://github.com/elixir-nx/nx/tree/main/nx
- https://github.com/elixir-nx/nx/tree/main/exla
- https://github.com/elixir-nx/scidata
- https://github.com/elixir-nx/explorer
- https://github.com/elixir-nx/axon
- https://github.com/elixir-nx/axon_onnx
- https://github.com/livebook-dev/livebook