LoginSignup
11
6

More than 1 year has passed since last update.

Axon + LiveViewでMNIST手書き文字識別

Last updated at Posted at 2022-02-18

はじめに

今Elixirでは以下のような機械学習関連のライブラリの開発が活発で、先月Nxが0.1になりました

  • 行列演算ライブラリ Nx
  • 行列演算高速化ライブラリ Exla
  • 機械学習用データセット取得ライブラリ SciData
  • Pandasライクなデータフローライブラリ Explorer
  • KerasライクなDLフレームワーク Axon
  • ONNXデータを読み込む axon_onnx (開発中)
  • Elixir製jupyter notebook LiveBook

今回はNx,Exla,SciData,Axonを使用してMNISTの手書き文字認識を行うアプリケーションを作成していきます
MNISTの手書き文字認識はディープラーニングのチュートリアルでよく使用される課題です
今回はアプリケーションとして作成するのでCanvasで描画した数字をLiveView経由でAxonに渡して文字を識別するようにします

DeepLeaningアプリケーション作成の流れ

だいたい以下のような流れになります

  • 解決するタスクを決める -> 数字の手書き文字認識
  • 使用するネットワークを構築 -> Axonのサンプルから 全結合 |> dropout |> 全結合
  • 学習データを集める -> SciDataからMNISTデータを取得
  • 学習を行う -> Axon.Loop.trainer
  • 学習結果とネットワークを書き出す -> erlang DETSに書き出す
  • アプリケーションで予測を行う -> erlang DETSに保存したネットワークと学習結果を使用

プロジェクト作成

作成していきます

mix phx.new live_axon --no-ecto
cd live_axon
mix.exs
defmodule LiveAxon.MixProject do
  use Mix.Project
  ...
  defp deps do
    [
      ...
      {:plug_cowboy, "~> 2.5"},
      {:axon, "~> 0.1.0-dev", github: "elixir-nx/axon", branch: "main"},
      {:exla, "~> 0.1.0-dev", github: "elixir-nx/nx", sparse: "exla"},
      {:nx, "~> 0.1.0-dev", github: "elixir-nx/nx", sparse: "nx", override: true},
      {:scidata, "~> 0.1.3"}
    ]
  end
end

学習コードの実装

以下を参考にしていきます

https://github.com/elixir-nx/axon/blob/main/examples/vision/mnist.exs
タスクの概要やデータの変換等機械学習に関しての詳細はこちらが大変参考になります
本記事は軽くコメントをつけるまでに留めます

priv/train/mnist.exs
defmodule Mnist do
  # exlaで使うデバイスを指定 rcomはamd系 gpuがなかったら最後のhost(cpu)になる
  EXLA.set_preferred_defn_options([:tpu, :cuda, :rocm, :host])

  # 必須
  require Axon

  # mnistの画像データをnxで使えるように変換
  defp transform_images({bin, type, shape}) do
    bin
    |> Nx.from_binary(type)
    |> Nx.reshape({elem(shape, 0), 784})
    |> Nx.divide(255.0)
    |> Nx.to_batched_list(32)
    # Test split
    |> Enum.split(1750)
  end

  # mnistのラベルデータをnxで使えるように変換
  defp transform_labels({bin, type, _}) do
    bin
    |> Nx.from_binary(type)
    |> Nx.new_axis(-1)
    |> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
    |> Nx.to_batched_list(32)
    # Test split
    |> Enum.split(1750)
  end

  # ニューラルネットワークの構築 
  # input_shape 28x28
  # dense1 hidden_layer 128
  # dense2 output 10 
  defp build_model(input_shape) do
    Axon.input(input_shape)
    |> Axon.dense(128, activation: :relu)
    |> Axon.dropout()
    |> Axon.dense(10, activation: :softmax)
  end

  defp train_model(model, train_images, train_labels, epochs) do
    model
    # 損失関数と活性化関数を指定
    |> Axon.Loop.trainer(:categorical_cross_entropy, Axon.Optimizers.adamw(0.005))
    # 1epochごとに計算する指標を正解率に指定
    |> Axon.Loop.metric(:accuracy, "Accuracy")
    |> Axon.Loop.run(Stream.zip(train_images, train_labels), epochs: epochs, compiler: EXLA)
  end

  defp test_model(model, model_state, test_images, test_labels) do
    model
    |> Axon.Loop.evaluator(model_state)
    |> Axon.Loop.metric(:accuracy, "Accuracy")
    |> Axon.Loop.run(Stream.zip(test_images, test_labels), compiler: EXLA)
  end

  def run do
    # scidataから画像とラベルのバイナリデータをダウンロード
    {images, labels} = Scidata.MNIST.download()

    # Nx用に変換して分割
    {train_images, test_images} = transform_images(images)
    {train_labels, test_labels} = transform_labels(labels)

    # モデルを構築してネットワーク構成を表示
    model = build_model({nil, 784}) |> IO.inspect()

    IO.write("\n\n Training Model \n\n")

    # epoch5で学習を実行し最終的なパラメーターを受け取る
    model_state =
      model
      |> train_model(train_images, train_labels, 5)


    IO.write("\n\n Testing Model \n\n")
    # テストデータで検証
    model
    |> test_model(model_state, test_images, test_labels)
    |> IO.inspect()
    IO.write("\n\n")

    {model, model_state}
  end

  # 上記を実行して detsにモデルと学習済み重みを保存
  def generate_trained_network do
    {model, weight} = Mnist.run()
    :dets.open_file('weight', type: :bag, file: 'weight.dets')
    :dets.insert('weight',{1,{model,weight}}) 
    :dets.sync('weight')
    :dets.stop    
  end
end
Mnist.generate_trained_network()

detsはets(Erlang Term Storage)というインメモリデータベースがあり、それをファイルに保存して永続化できる方をdetsといいます
Elixirのオブジェクトをread/writeできて永続化もできるのでAxonのモデルと学習済み重みを保存するのに大変都合がいいです。

学習コードのタスク化

タスクとして登録しましょう

mix.exs
defmodule LiveAxon.MixProject do
  use Mix.Project
  ...
  defp aliases do
    [
      setup: ["deps.get"],
      train: ["run priv/train/mnist.exs"], # 追加
      "assets.deploy": ["esbuild default --minify", "phx.digest"]
    ]
  end
end
mix train
# 以下ログ出力
--------------------------------------------------------------
                            Model
==============================================================
 Layer                                Shape        Parameters
==============================================================
 input_0 ( input )                    {nil, 784}   0
 dense_0 ( dense[ "input_0" ] )       {nil, 128}   100480
 relu_0 ( relu[ "dense_0" ] )         {nil, 128}   0
 dropout_0 ( dropout[ "relu_0" ] )    {nil, 128}   0
 dense_1 ( dense[ "dropout_0" ] )     {nil, 10}    1290
 softmax_0 ( softmax[ "dense_1" ] )   {nil, 10}    0
--------------------------------------------------------------

 Training Model 

[warn] Failed to get CPU frequency: 0 Hz
[info] XLA service 0x111e05fa0 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
[info]   StreamExecutor device (0): Host, Default Version
Epoch: 0, Batch: 1750, Accuracy: 0.9750814 loss: 0.4228649
Epoch: 1, Batch: 1750, Accuracy: 0.9819489 loss: 0.3669927
Epoch: 2, Batch: 1750, Accuracy: 0.9837040 loss: 0.3398788
Epoch: 3, Batch: 1750, Accuracy: 0.9843555 loss: 0.3236680
Epoch: 4, Batch: 1750, Accuracy: 0.9850162 loss: 0.3118583

 Testing Model 

Batch: 125, Accuracy: 0.9946750%{
  0 => %{
    "Accuracy" => #Nx.Tensor<
      f32
      0.9946749806404114
    >
  }
}

Apple M1 Max で大体3分くらいで完了しました
学習が完了するとプロジェクトのトップにweight.detsというファイルが作成されているのでこちらを使用して文字識別を行っていきます

LiveView + Canvasで手書き文字画像生成

ページ作成

最初にページのみ作成します

lib/live_axon_web/live/page_live.ex
defmodule LiveAxonWeb.PageLive do
  use LiveAxonWeb, :live_view

  @impl true
  def mount(_,_,socket) do
    {:ok, socket}
  end
end
lib/live_axon_web/live/page_live.html.heex
<div class="row">
  <div class="column column-25">
    <h1>Actions</h1>
    <dl>
      <dt><button class="button">Predict</button></dt>
      <dt><button class="button">Clear</button></dt>
    </dl>
  </div>
  <div class="column">
    canvas
  </div>
</div>
lib/live_axon_web/router.ex
defmodule LiveAxonWeb.Router do
  use LiveAxonWeb, :router

  scope "/", LiveAxonWeb do
    pipe_through :browser

    get "/", PageController, :index # 消す 
    live "/", PageLive 
  end
end

Canvas数字描画実装

いわゆるお絵かきツールですね
こちらを参考に機能を削って実装します

assets/js/hooks.js
let Hooks = {};
Hooks.Mnist = {
  mounted() {
    let canvas = document.getElementById("canvas");
    let ctx = canvas.getContext("2d");
    let clickFlg = 0;
    canvas.width = 300
    canvas.height = 300

    const draw = (x, y) => {
      ctx.lineWidth = 10;
      ctx.strokeStyle = 'rgba(255, 255, 255, 1)';
      if (clickFlg == "1") {
        clickFlg = "2";
        ctx.beginPath();
        ctx.lineCap = "round";
        ctx.moveTo(x, y);
      } else {
        ctx.lineTo(x, y);
      }
      ctx.stroke();
    };

    const setBgColor = () => {
      ctx.fillRect(0,0,300,300);
    }    
    setBgColor();

    canvas.addEventListener("mousedown",() => {
      clickFlg = 1;
    })
    canvas.addEventListener("mouseup", () => {
      clickFlg = 0;
    })
    canvas.addEventListener("mousemove",(e) => {
      if(!clickFlg) return false;
      draw(e.offsetX, e.offsetY);
    }); 
  }
}
export default Hooks;

今回はhooksは1つなのでなくてもいいのですが、複数のhooksを作成する場合に連想配列を作成してそこの1つの値としてhookを1つ追加します
mounted関数内に記述すると、LiveViewのページでmount関数実行後に実行されます

hooksを作成したらLiveSocketのパラメーターにhooksで追加します

assets/js/app.js
import Hooks from "./hooks";
let liveSocket = new LiveSocket("/live", Socket, {hooks: Hooks, params: {_csrf_token: csrfToken}})

LiveViewページでhooksの読み込みとcanvas要素を追加します
phx-hookを指定して読み込む要素はidが必要なので注意してください
またcanvasの変更を検知してLiveViewが再描画を行わないようにphx-update="ignore"を指定します

lib/live_axon_web/live/page_live.html.heex
<div id="mnist" phx-hook="Mnist" class="row">
  <div class="column column-25">
  ...
  </div>
  <div class="column">
    <canvas id="canvas" phx-update="ignore"></canvas>
  </div>
</div>

描画初期化ボタン

描画を初期化する機能をLiveViewとhooksで実装します
push_eventでJS hooksの関数を実行します

lib/live_axon_web/live/page_live.ex
defmodule LiveAxonWeb.PageLive do
  use LiveAxonWeb, :live_view
  ...
  @impl true
  def handle_event("clear", _params, socket) do
    {:noreply, push_event(socket, "clear", %{})}
  end
end
assets/js/hooks.js
let Hooks = {};
Hooks.Mnist = {
  mounted() {
    ...
    this.handleEvent("clear", () => {
      ctx.clearRect(0, 0, canvas.width, canvas.height);
      setBgColor();
    });
  }
}
export default Hooks;

画像生成

Predictボタンにクリックイベントとリサイズ用のcanvasを追加します

lib/live_axon_web/live/page_live.html.heex
<div id="mnist" phx-hook="Mnist" class="row">
  <div class="column column-25">
    <h1>Actions</h1>
    <dl>
      <dt><button class="button" phx-click="predict">Predict</button></dt>
      <dt><button class="button" phx-click="clear">Clear</button></dt>
    </dl>
  </div>
  <div class="column">
    <canvas id="canvas" phx-update="ignore"></canvas>
    <canvas id="canvas2" style="visibility:hidden"></canvas>
  </div>
</div>
lib/live_axon_web/live/page_live.ex
defmodule LiveAxonWeb.PageLive do
  use LiveAxonWeb, :live_view
  ...
  @impl true
  def handle_event("predict", _params, socket) do
    {:noreply, push_event(socket, "predict", %{})}
  end
end

リサイズ出力用canvasを作成し
MNISTと同じサイズの28x28にリサイズしてImageDataをpushEventでLiveView側に送信します。

assets/js/hooks.js
let Hooks = {};
Hooks.Mnist = {
  mounted() {
    ...
    let canvas2 = document.getElementById("canvas2");
    let ctx2 = canvas2.getContext("2d");
    canvas2.width = 28;
    canvas2.height = 28;

    this.handleEvent("predict", () => {
      ctx2.drawImage(canvas, 0, 0, 300, 300, 0, 0, 28, 28);
      let data = ctx2.getImageData(0, 0, 28, 28);
      ctx2.fillRect(0, 0, 28, 28);
      this.pushEvent("predict_axon", data);
    });     

  }
}
export default Hooks;

推論

predict_axonを実装します

データの加工

JS側から送信されたImageDataは以下のようなrgbaのUint8ClampedArrayを返すのでデータを加工します

[
  {"864714", 61},
  {"858981", 25},
  {"753417", 99},
  {"985846", 59},
  {"993186", 117},
  {"486748", 234},
  {"149540", 231},
  {"948539", 255},
  ...
]
lib/live_axon_web/live/page_live.ex
defmodule LiveAxonWeb.PageLive do
  use LiveAxonWeb, :live_view

  @impl true
  def handle_event("predict_axon", %{"data" => data}, socket) do
    ans =
      data
      |> convert_image_data_to_tensor()
      |> convert_mnist_predict_data()

    {:noreply, socket}
  end

  def convert_image_data_to_tensor(data) do
    data
    |> Map.to_list()
    |> Enum.map(fn {k, v} -> {String.to_integer(k), v} end)
    |> Enum.sort()
    |> Enum.map(fn {_k, v} ->  v end)
    |> Nx.tensor()
  end

  def convert_mnist_predict_data(pixel) do
    {row} = Nx.shape(pixel)
    pixel
    |> Nx.reshape({div(row, 4), 4})
    |> Nx.slice_along_axis(0, 3, axis: 1)
    |> Nx.mean(axes: [-1])
    |> Nx.round()
    |> Nx.reshape({1, 784})
    |> tap(& &1 |> Nx.reshape({28,28})|> Nx.to_heatmap() |> IO.inspect())
    |> Nx.divide(255.0)
  end
end

convert_image_data_to_tensorは以下のような処理を行っています

Map
|> {key(string), val(integer)}のタプルリスト
|> {key(integer), val(integer)}のタプルリスト
|> タプルのkeyでsort
|> タプルのvalのみ抽出
|> Nx.tensor化

onvert_mnist_predict_dataは以下のような処理を行っています

Nx.shape(pixel) #要素数を取得

pixel
|> [{r,g,b,a},....]の形式に変換
|> [{r,g,b},....]の形式に変換
|> [avg,....]の形式に変換
|> 平均の小数点以下を丸める
|> {784}をAxon用に{1,784}に変換
|> 28x28にしてheatmapにして表示
|> 255で割って0~1.0の間に変換

Axonで推論

学習コードで作成したモデルと重みデータをdetsから読み込んでAxonで推論を実行します
答え表示用にansをassignし、clear eventでも初期化するようにします
predictの結果はNx.Tensorなので|> Nx.argmax() |> Nx.to_number()で一番確率が高いindexを取得して数値に変換しています

lib/live_axon_web/live/page_live.ex
defmodule LiveAxonWeb.PageLive do
  use LiveAxonWeb, :live_view
  require Axon # requireしないと使えない

  @impl true
  def mount(_,_,socket) do
    {:ok, assign(socket, :ans, nil)} 
  end

  @impl true
  def handle_event("clear", _params, socket) do
    {
      :noreply,
      socket
      |> assign(:ans, nil)
      |> push_event("clear", %{})
    }
  end
  ...

  @impl true
  def handle_event("predict_axon", %{"data" => data}, socket) do
    ans =
      data
      |> convert_image_data_to_tensor()
      |> convert_mnist_predict_data()
      |> predict()

    {:noreply, assign(socket, :ans, ans)}
  end

  ...
  def predict(pixel) do
    {:ok, params} = :dets.open_file('weight.dets')
    [{1,{model,weight}}] = :dets.lookup(params,1)
    Axon.predict(model,weight, pixel) |> Nx.argmax() |> Nx.to_number()
  end
end

推論結果が出たらansを表示するようにします

lib/live_axon_web/live/page_live.html.heex
<div id="mnist" phx-hook="Mnist" class="row">
  <div class="column column-25">
    <h1>Actions</h1>
    <dl>
      <dt><button class="button" phx-click="predict">Predict</button></dt>
      <dt><button class="button" phx-click="clear">Clear</button></dt>
    </dl>
    <%= if @ans do %>
      <h1>Answer:<%= @ans %></h1>
    <% end %>
  </div>
  <div class="column">
    <canvas id="canvas" phx-update="ignore"></canvas>
    <canvas id="canvas2" style="visibility:hidden"></canvas>
  </div>
</div>

デモ

Image from Gyazo

最後に

Axon,Nx,SciData,LiveView,Canvas,DETSを使ってMNIST手書き文字認識のアプリケーションを作成できました!
学習結果とモデルをDETSに保存させればResNetで画像分類などを簡単にweb applicationに組み込むことができますし
ONNXも読み込めれば他の学習済みネットワークを使用できたりして夢が広がります!

本記事は以上になりますありがとうございました

code

参考ページ

11
6
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
11
6