matlab
AWS
nodejs
AngularJS
DeepLearning

スマホで撮った写真をディープラーニングで判定してみる (デプロイ編)

@p_panther さんが「スマホで撮った写真をディープラーニングで判定してみる (学習編)」で作ってくれたディープラーニングによる学習モデルを用いて、画像から何が写っているのか推論するアプリケーションとしてデプロイしてみます。

ディープラーニングによる推論をシステム化するための検討

ディープラーニングによる推論アルゴリズムをシステム化する上で、様々なユーザーの環境からでも使えるよう、Webブラウザで使えるWebアプリケーションにするのが最も良さそうです。WebアプリケーションならPCからでもスマートフォンからでも、タブレットからでも使えますので。

MATLABプログラムをMATLABのサーバー製品であるMATLAB Production Server(MPS)向けにコンパイルし、WebアプリケーションのバックエンドとしてRESTful APIで実行できるようにします。

日本だけではなく海外の人にも試してもらうようなアプリケーションにしたいので、操作方法もかなりシンプルにして、画像をアップロードすれば写っているものを推論して返すだけのアプリケーションにしてみたいと思います。

MATLAB_Deployment.png

Webブラウザ実装

HTML/JavaScript側で実装するものとしては、
- NodeJSへの画像アップロード
- NodeJSからのリターンを受け取りMPSのRESTful APIをcall
- MPSからリターンされたFigureのバイト配列をbase64に変換し、imgタグに貼り付け

ctrl.js(NodeJSへのアップロード部分)
var formData = new FormData();

// Append to varialbe "uploads[]" which is defined in upload-input in index.html
formData.append('uploads[]', file, file.name);

// Post to NodeJS
$.ajax({
    url: '/uploadImage',
    type: 'POST',
    data: formData,
    processData: false,
    contentType: false,
    success: function (data) {
        // Set uploaded file name to scope variable
        $scope.imageFile = 'uploads/' + data;
        // And seamlessly do MPS calculation
        $scope.doImageProcessing($scope.imageFile, $scope.selectedNetwork);
    }
});
ctrl.js(MPS呼び出し部分)
$scope.doImageProcessing = function (imageFile, selectedNetwork) {
    // Set network switch
    $scope.netsw = 1;
    if (selectedNetwork !== 'ReTrained') {
        $scope.netsw = 2;
    }

    //input arguments to MPS
    var inputArg = {
        "nargout": 1, //num_of_requested_outputs
        "rhs": [imageFile, $scope.netsw], //input arguments
        "outputFormat": {
            "mode": "large", // small or large
            "nanInfFormat": "object" // string or object
        }
    };

    //POST to MPS
    var req = {
        method: 'POST',
        url: 'https:://HOST:PORT/myPrediction_MPS/myPrediction_MPS',
        headers: {
            'Accept': 'application/json',
            'Content-Type': 'application/json'
        },
        data: inputArg
    };

    $http(req).
        then(function (response) {
            $scope.results = response.data;
            if ($scope.results.error) {
                // MATLAB error handle

                return;
            } else {
                // Change img source
                var matlabPlotData = $scope.results.lhs[0].mwdata;
                matlabPlotData = $scope.base64ArrayBuffer(matlabPlotData);
                matlabPlotData = 'data:image/png;base64,' + matlabPlotData;
                $("#plotImg").attr("src", matlabPlotData);
            }
        }, function (response) {
            // HTTP error handle
    });
};

// function for converting byte streams to back to images for the web
// From https://gist.github.com/jonleighton/958841
// MIT LICENSE
$scope.base64ArrayBuffer = function (arrayBuffer) {
    var base64 = '';
    var encodings = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/';
    var bytes = new Uint8Array(arrayBuffer);
    var byteLength = bytes.byteLength;
    var byteRemainder = byteLength % 3;
    var mainLength = byteLength - byteRemainder;
    var a, b, c, d;
    var chunk;
    // Main loop deals with bytes in chunks of 3
    for (var i = 0; i < mainLength; i = i + 3) {
        // Combine the three bytes into a single integer
        chunk = (bytes[i] << 16) | (bytes[i + 1] << 8) | bytes[i + 2];
        // Use bitmasks to extract 6-bit segments from the triplet
        a = (chunk & 16515072) >> 18; // 16515072 = (2^6 - 1) << 18
        b = (chunk & 258048) >> 12; // 258048   = (2^6 - 1) << 12
        c = (chunk & 4032) >> 6; // 4032     = (2^6 - 1) << 6
        d = chunk & 63; // 63       = 2^6 - 1

        // Convert the raw binary segments to the appropriate ASCII encoding
        base64 += encodings[a] + encodings[b] + encodings[c] + encodings[d];
    }

    // Deal with the remaining bytes and padding
    if (byteRemainder == 1) {
        chunk = bytes[mainLength];
        a = (chunk & 252) >> 2; // 252 = (2^6 - 1) << 2

        // Set the 4 least significant bits to zero
        b = (chunk & 3) << 4; // 3   = 2^2 - 1

        base64 += encodings[a] + encodings[b] + '==';
    } else if (byteRemainder == 2) {
        chunk = (bytes[mainLength] << 8) | bytes[mainLength + 1];
        a = (chunk & 64512) >> 10; // 64512 = (2^6 - 1) << 10
        b = (chunk & 1008) >> 4; // 1008  = (2^6 - 1) << 4

        // Set the 2 least significant bits to zero
        c = (chunk & 15) << 2; // 15    = 2^4 - 1

        base64 += encodings[a] + encodings[b] + encodings[c] + '=';
    }
    return base64;
};

NodeJSによるサーバー実装

NodeJSではWebサーバーとして機能させつつ、アップロードされたファイルにリネーム処理を行います。これはファイルアップロードが同じタイミングで行われたときに、ファイル名の競合が起きないようにするためです。また、HTML/JavaScriptからのファイルアップロードに対してはformidableを使ってFormDataをパースしています。

アップロード処理のリターンはタイムスタンプ付きのファイル名を返すようにしており、これを後のJavaScriptでのMPS呼び出しの引数にしています。

app.js(一部)
var express = require('express');
var app = express();
var path = require('path');
var fs = require('fs');
var formidable = require('formidable');
var http = require('http');
var https = require('https');

app.use(express.static(path.join(__dirname, 'public')));

// For Express 4
var bodyParser = require('body-parser');
app.use(bodyParser.urlencoded({limit: '50mb', extended: false}));    // to support URL-encoded bodies
app.use(bodyParser.json({limit: '50mb'}));       // to support JSON-encoded bodies {limit: '50mb'} to avoid request entity too large

app.get('/', function (req, res) {
    res.sendFile(path.join(__dirname, 'public/index.html'));
});

// For static image
app.post('/uploadImage', function (req, res) {
    // create an incoming form object
    var form = new formidable.IncomingForm();

    // specify that we want to allow the user to upload multiple files in a single request
    form.multiples = true;

    // store all uploads in the /uploads directory
    form.uploadDir = path.join(__dirname, '/uploads');

    // every time a file has been uploaded successfully,
    // rename it to it's orignal name with time stamp to avoid the same name conflicts
    form.on('file', function (field, file) {
        var newFileName = String(Date.now()) + file.name;
        fs.renameSync(file.path, path.join(form.uploadDir, newFileName));
        res.write(newFileName);
    });

    // log any errors that occur
    form.on('error', function (err) {
        console.log('An error has occured: \n' + err);
    });

    // once all the files have been uploaded, send a response to the client
    form.on('end', function () {
        res.end();
    });

    // parse the incoming request containing the form data
    form.parse(req);
});

MATLABコードをコンパイル

ディープラーニングの推論を行うMATLABスクリプトをコンパイルします。ニューラルネットワークの読み込みは時間が掛かるのですが、MATLAB Production Serverはアプリケーションを常駐させるので、persistent変数を使ってメモリー上に保持させることで2回目以降の処理を高速化できます。
スマートフォンのカメラで撮影された画像は回転属性が含まれているのでOrientationで判別します。ここでisfield関数を使って、Orientationというフィールドがあるかどうかの判別を行っています。

入力は画像ファイルのパスと、ニューラルネットワークの指定、の2つです。
関数の出力はresultsという1変数だけで、Figureをバイト配列にしたものを返しています。

myPrediction_MPS.m
function results = myPrediction_MPS(imageFile, netsw)
% Load pre-trained network(for feature extraction)
persistent net myNet pro frame frame1 scores fig
if isempty(net)
    % Load AlexNet
    %net = alexnet();
    net = load('net.mat');
    net = net.net;

    % Load retrained network
    myNet = load('myNet.mat');
    myNet = myNet.myNet;

    pro = '.   Prob: ';
    frame = zeros(360,540,3);
    frame1 = zeros(227,227,3);
    scores =zeros(1,10);
    fig = figure('Visible','off');
    fig.Color = 'white';
end
if exist(imageFile) ~= 2
    error('File %s not found', imageFile);
end

%%
% get image
% Resize the image to avoid too small text label from iPhone
frame = imread(imageFile);
sz = size(frame);

% Get image meta information
info = imfinfo(imageFile);
% If an image is jpeg and contains exif
if isfield(info, 'Orientation')
    if info.Orientation == 6
        % If orientation is 6 (right-top), rotate image
        frame = imrotate(frame, -90);
        frame = imresize(frame, [540 360]);
    else
        % In case of info.Orientation == 1
        frame = imresize(frame, [360 540]);
    end
else
    % In case of non JPEG-EXIF formats
    frame = imresize(frame, [360 540]);
end

frame1 = imresize(frame, [227 227]);% googleNet is [224 224], AlexNet is [227 227]
% get extracted features at fc7
if netsw == 1
    % Use re-trained alexnet
    [labelIdx,scores] = classify(myNet, frame1);
else
    % Use alexnet
    [labelIdx,scores] = classify(net, frame1);
end
Smax = max(scores);
fontSize = 30;
frame = insertText(frame, [0 5], [char(labelIdx) pro num2str(Smax)], 'FontSize', fontSize);

%% Return figure as bytes arrays
imshow(frame);
results = figToImStream('figHandle', fig, 'imageFormat','jpg', 'outputType','int8');
end

このmファイルをMATLABの「アプリ」から「Production Serverコンパイラ」をクリックし、エクスポートする関数に含め、パッケージ化をクリックしてコンパイルします。コンパイルがうまくいくと、MPS用のアプリケーション(CTFファイル)が作成されますので、これを運用サーバーのMPSにアップロードします。
201888_134312.png

クラウドへ展開

作成したアプリケーションをクラウドへ展開します。今回はAWSを使ってElastic IPとAmazon Route 53を用いて、グローバルIPとDNSを対応付けているので、例えばある1日だけ大規模なイベントがあってこのアプリケーションへのアクセス数が激増する可能性がある場合に、一時的にインスタンスをCPUコア数が多く、RAMが大きいスペックの高いものに切り替える際も、引き続き同じURL(https://deeplearning.mwlab.io/)でアクセスできるようになり、クラウドならではの柔軟なスケーラビリティのメリットを得ることができます。

MPSのアプリケーションのルートディレクトリ(例:C:\Work\MPS_Dashboard_Work\mps_workspace\Instances\mps_1)とNodeJSのルートディレクトリ(例:C:\Work\WebApps)の間でシンボリックリンクを貼っておき、NodeJSにアップロードされた画像ファイルがMPSからアクセスできるようにします。

Windows環境ならコマンドプロンプトを管理者権限で起動し、以下のコマンドを実行します。
cd C:\Work\MPS_Dashboard_Work\mps_workspace\Instances\mps_1
mklink /D uploads C:\Work\WebApps\uploads

スマホからアクセス

試しにスマホでWebブラウザからこのWebアプリケーションのURLにアクセスし、硬貨の写真を撮って判定させてみます。

学習済みのAlexNet(ILSVRCで優勝した際のパラメータが事前学習された1000クラスを分類するネットワーク)をそのまま使うメリットとしては、1000個のクラスが含まれているので汎用的な分類に強いことです。ただし、クラスに含まれていないものを分類しようとしてもうまくいきません。例えば、硬貨の画像を分類させると、
result_AlexNet.jpg

puck(パック。アイスホッケー用のゴムの円盤)として分類されてしまいました。こういう時にAlexNetのネットワークを用いて独自の10クラスで転移学習させたネットワークを再学習したモデルを使うと、

result_reTrainedAlexNet.jpg

ちゃんとCoin(硬貨)として分類されました。用途に合わせて再学習したネットワークを作ることで対象とする画像について精度の高い推論結果を得ることができます。

使用したソフトウェア

開発用PC

  • MATLAB
  • Computer Vision System Toolbox
  • Neural Network Toolbox
  • MATLAB Compiler
  • MATLAB Compiler SDK

運用サーバー

  • NodeJS
  • Let's Encrypt (SSL証明書発行のため)
  • MATLAB Production Server
  • MATLAB Runtime