14
12

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

MATLAB のニューラルネットをブラウザで動かす: MATLAB > C++ > WebAssembly の自動変換

Last updated at Posted at 2020-08-23

0. はじめに

MATLAB をブラウザ実装したら面白そうなものないかなぁ・・と思っていたら、素晴らしい記事が見つかったのでマネをさせてもらいました。

TensorFlow.jsでMNIST学習済モデルを読み込みブラウザで手書き文字認識をする

この記事では TensorFlow.js でWebブラウザ上で書いた数字が 0 ~ 9 のどれかを予測していますが、この予測部分に MATLAB のニューラルネットを使ってみた、というお話。

やったこと

コードこちらから: GitHub: minoue-xx/handwritten-digit-prediction-on-browser
実行ページはこちら:Github Pages: Hand-written Digit Prediction on Browser

予測精度は怪しいですが、ここではまず動くものを、、と単純なネットワークを学習して GitHub Pages に実装するまでのステップを解説する記事です。

attach:cat

WebAssembly 変換?

MATLAB -> C++ -> WebAssembly の自動変換を使った非線形最適化 on JavaScript では fmincon 関数を使った最適化をブラウザ実装しましたが、同じ方法を使います。

File Exchange に公開されている Generate JavaScript Using MATLAB Coder というツールで、MATLAB Coder を使って MATLAB から WebAssembly に変換して実装します。

image_1.png

基本的には Generate JavaScript Using MATLAB Coder で用意されている 例題: Pass Data to a Library の流れに沿って作業しています。

環境

1. ツールの設定

基本的には MATLAB -> C++ -> WebAssembly の自動変換を使った非線形最適化 on JavaScript と同じステップを踏みます。

File Exchange から Generate JavaScript Using MATLAB Coder をインストール。まず開く Setup.mlx に従って Emscripten Development Kit の最新版をインストールします。ネットワークフォルダだとうまくいかなかったので、ローカルにインストールください。

2. MATLAB Project 作成

Generate JavaScript Using MATLAB Coder では MATLAB Project を使用します。

作業フォルダ(trainingModels/generateWebAssembly)に移動して、以下を実行。出力形式は Dynamic Libeary (dll) です。

proj = webcoder.setup.project("digitprediction","Directory",pwd,"OutputType",'dll');

3. MATLAB 関数作成

非線形最適化を実施する関数 digitPredictionFcn.m を作ります。まずは単純に隠れ層が1層だけの”浅い”ネットワークでやってみます。

  • Step 1: データ読み込み
  • Step 2: 学習
  • Step 3: モデルの MATLAB 関数化
  • Step 4: コード生成用に微修正
attach:cat

学習ステップはこちらを参考にしました: MathWorks Blog: Artificial Neural Networks for Beginners

MATLAB 関数作成詳細

Step 1: データ読み込み

%% サンプルデータ読み込み(Deep Learning Toolbox に入っているデータです)
[XTrain,YTrain,anglesTrain] = digitTrain4DArrayData;
classNames = categories(YTrain);
numClasses = numel(classNames);
numObservations = numel(YTrain);
whos XTrain
  Name         Size                      Bytes  Class     Attributes

  XTrain      28x28x1x5000            31360000  double              

28x28 のモノクロ画像が 5000 枚入っていますね。

montage(XTrain(:,:,:,1:16))
attach:cat

こんな画像です。以下の通り各ラベル(数字)につき 500 枚の画像が用意されています。

tabulate(YTrain)
  Value    Count   Percent
      0      500     10.00%
      1      500     10.00%
      2      500     10.00%
      3      500     10.00%
      4      500     10.00%
      5      500     10.00%
      6      500     10.00%
      7      500     10.00%
      8      500     10.00%
      9      500     10.00%

Step 2: 学習

この辺は MathWorks Blog: Artificial Neural Networks for Beginners にある通りアプリでもできます。以下はアプリから吐き出した MATLAB コードをほぼそのまま流用。

outputs = dummyvar(YTrain);       % convert label into a dummy variable
outputs = outputs';               % transpose dummy variable
inputs = reshape(XTrain,28*28,5000);

rng(1); % 再現性確保のため乱数初期化

x = inputs;
t = outputs;
trainFcn = 'trainscg';  % Scaled conjugate gradient backpropagation.

% Create a Pattern Recognition Network
hiddenLayerSize = 100;
net = patternnet(hiddenLayerSize, trainFcn);

% Setup Division of Data for Training, Validation, Testing
net.divideParam.trainRatio = 70/100;
net.divideParam.valRatio = 15/100;
net.divideParam.testRatio = 15/100;

% Train the Network
[net,tr] = train(net,x,t);

混同行列を書いてみると悪くない結果です。詳細はひとまず問いません。

plotconfusion(t,net(x))
attach:cat

Step 3: モデルの MATLAB 関数化

学習したネットワーク net をコード生成にもっていくために MATLAB 関数に変換します。

genFunction(net, 'digitPredictFcn');
Output
 
MATLAB function generated: digitPredictFcn.m
To view generated function code: edit digitPredictFcn
For examples of using function: help digitPredictFcn
 

こんな感じで digitPredictionFcn.m 、ネットワークの重みも情報も全部書き出した関数ファイルが出来上がります。

Step 4: コード生成用に微調整

残念ながら・・ここでできた関数ではコード生成に対応していない書き方が混ざっていたので、 C++ に自動変換するには少し手を入れる必要がありました。出来上がったものはこちら: GitHub: digitPredictFcn.m

コード生成用に行った微調整詳細(4項目)

Step 4-1: 関数定義、アサーション

function [Y,Xf,Af] = digitPredictFcn(X,~,~) 

と定義されていますが、以下の通りシンプルに1入力1出力の関数にします。

function YY = digitPredictFcn(XX)

また、コード生成用に入力引数のサイズを明記しておく必要があるので、

assert(isa(XX, 'double'));
assert(all( size(XX) == [ 28*28, 1 ]));

を冒頭に追記します。

Step 4-2: インメモリ処理

X = {X}

double 型をセル配列に置き換えるこんな処理はダメみたいです。

% Format Input Arguments
isCellX = iscell(XX);
if ~isCellX
  X = {XX};
end

入力を変数 XX としセル配列にしたものを X に。この後はそのまま。

Step 4-3: cell2mat 関数

予測結果がセル配列 Y で帰ってきて、最後に double 型に変換していますが、cell2mat 関数はコード生成対象外。

Y = cell2mat(Y);

ここは乱暴な気もしますが入力として入ってくる変数サイズは決まっているので、

YY = Y{:};

と置き換えて、YY を出力変数とします。

関数ができたら、digitPredictFcn.m をプロジェクトに追加して、ラベルを UserEntryPoints > Function に設定しておきます。

attach:cat

こんな感じ。

4. JavaScript と WebAssembly の生成

以下のコードで MATLAB Project からビルドします。裏で MATLAB Coder + Emscripten SDK が走ります。

proj = openProject(pwd);
webcoder.build.project(proj);
コードの生成が成功しました:レポートの表示

無事に終了すると、C++ コードが build フォルダに出力されます。さらに、この C++ コードが digitprediction.jsdigitprediction.wasm にコンパイルされて、dist フォルダに出力されます。

5. HTML/JavaScript から呼び出し

さて、ようやく本題。MDN によると

JavaScript typed arrays are array-like objects and provide a mechanism for accessing raw binary data.

とのことで、この JavaScript typed arrays を使って、JavaScript から optimizeposition.wasm とデータをやり取りします。

処理の流れ

  1. JavaScript typed array を作成
  2. typed array の要素数から必要な領域を計算、wasm 側のメモリを確保
  3. 確保した領域に typed array の値をコピー
  4. wasm 側の計算処理を実行
  5. wasm 側のメモリから typed array に値をコピー
  6. 不要になった領域を解放

1-3 の処理をしているのが _arrayToHeap、5 が _heapToArray です。処理の詳細は Guthub: Planeshifter/emscripten-examples の README.md の記述が参考になります。

JavaScriptからの呼び出しコード
getAccryracyScores

    function getAccuracyScores(imageData) {

      let inputs = [];
      let length = 28 * 28; // ピクセルサイズ

      for (let i = 0; i < length * 4; i = i + 4) { // 必要なピクセルだけ取り出します
        inputs.push(imageData.data[i] / 255);
      }
      console.log(inputs); // 確認

      var Inputs = new Float64Array(inputs);
      var Outputs = new Float64Array(10);

      // Move Data to Heap var
      var Inputsbytes = _arrayToHeap(Inputs);
      var Outputsbytes = _arrayToHeap(Outputs);

      // Run Function
      Module._digitprediction_initialize();
      Module._digitPredictFcn(Inputsbytes.byteOffset, Outputsbytes.byteOffset)
      Module._digitprediction_terminate();

      // Copy Data from Heap 
      Outputs = _heapToArray(Outputsbytes, Outputs);
      var outputs = Array.from(Outputs);

      // Free Data from Heap 
      _freeArray(Inputsbytes);
      _freeArray(Outputsbytes);

      // Display Results
      console.log(outputs);
      const score = outputs;
      return score;
    }

6. 結果をみてみよう

ローカルサーバを立てて結果をみてみます。Fetch API は file URI Scheme をサポートしていないため、ファイルに http URI Scheme でアクセスできるようにする必要があるらしい。Generate JavaScript Using MATLAB Coder には関数が用意されてますのでこれを使います。

先ほど .js.wasm が出力されたフォルダ dist に index.html を置きます。MATLAB 上で dist をカレントフォルダにして、

server = webcoder.utilities.DevelopmentServer("Port",8125);
start(server);
Development Server serving directory '.' at locations:
    http://10.0.1.14:8125
    http://localhost:8125
% サーバを落とすときは
% stop(server);
attach:cat

計算はできている模様。ひとまずめでたしめでたし。

7. まとめ

ひとまず MATLAB で書いたネットワークが JavaScript から正しく呼び出せていることが確認できました。

ただ・・実際に Github Pages: Hand-written Digit Prediction on Browser から試してみるとわかる通り、精度はいまいち。

入力パッドから入ってくる画像が、学習に使用した画像と性質が異なるのか・・そもそも 5,000 枚しかなかったのが原因なのか?この辺についてはまた次回取り組んでみます。

14
12
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
14
12

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?