Chainer
pix2pix
WebDNN

WebDNNで猫の線画着色モデルを動作させる

More than 1 year has passed since last update.

この記事について

以前作成した、pix2pixによる猫画像の線画着色モデルを、ブラウザ側で動作させるようにしました。そこに至るまでの軌跡を解説します。

動作サイト (pix2pix.daio.net)

これに伴い、今までサーバーサイドで動作させていたpix2pixは止めました。

WebDNN

既存のフレームワークで学習したモデルをWebブラウザから利用できるようにしたものとして、deeplearnjsWebDNNがあります。

当初はTensorFlow実装のpix2pixモデルをそのまま使うことを考えていたため、deeplearnjsを試すことを考えていました。しかしJSで計算グラフの定義をする必要がある(pbtxtを読んでグラフを勝手に作ってくれるとかはしない)ので、いったん採用を保留しました。

次にWebDNNの利用を検討しました。WebDNNはMakeGirls.moeで既に利用されており実績があります。しかしながら、TensorFlowの対応は未だ実験的ということだったので、kerasやchainerを使うことにしました。

pix2pix keras実装

kerasで記述されたpix2pix実装は複数ありますが、直接利用するには問題がありました。

一応vess2retベースのモデルでWebDNNの動作も出来はしたのですが、一部うまくいかない部分があり(ちょっと記憶がおぼろげ)chainerを使うことにしました。

pix2pix chainer実装

chainerで実装されたpix2pixを探したところ、PFN公式実装が存在していました。

この実装をベースに、入力3チャンネルのカラー画像が受けられ、以前作成した訓練画像が直接読み込めるよう手を入れたバージョンを作成しました。

スクリプト

自分のリポジトリには、訓練データや訓練済みモデルなどをダウンロードするスクリプトを用意してあります。

単純にブラウザで試すだけであれば、WebDNNのwebassembly backendモデルを取得、展開するget-webdnn-model.shを実行するだけで良いです。

データセットからモデルの訓練をはじめから行うには、以下のスクリプトを実行してください。

動作

webassemblyはローカルから直接読んで実行することを禁止しているブラウザが大半のようなので、http経由でhtml, outディレクトリが見える場所に置いてhtml/index.htmlを読み込んでください。

ページを読み込むとモデル読み込みが自動で始まります。読み込みが完了したら、左側のキャンバスに線画を書いて"convert"を押すことでpix2pixモデルの計算が始まります。i5-3750を積んだマシンとDebian stretch, Chrome 61の組み合わせで計算に25秒程度かかります。Nextbit Robin(Snapdragon 808)で106秒程度、OnePlus3T(Snapdragon 821)で47秒程度でした。

"preset"ボタンを押すとこちらで用意しておいたプリセット線画が設定されます。訓練画像と同じものなので、これを使って変換をさせるとかなりきれいな結果が出ます。

image.png

ここに至るまでの道のり

いくつかの勘違い

WebDNNの各種APIは非同期動作するものが多くPromiseが使われているので、その点を踏まえる必要がありました。.then()が使えるのでそちらを活用しましょう。

また、順方向計算では不要な処理(batch normalizationなど)は取り除く必要があります。最初これに気づかず、Issueを立ててしまいました。mnistサンプルやチュートリアルには特に記載がないのですが、ResNetのサンプルにはchainer.using_config('train', False)を使った例があります。

Deconvolution2Dのバグ

pix2pixのU-Net decoder側にはDeconvolution2Dが使われているのですが、当初このコードは適切に動いていませんでしたAutoEncoderを使った再現コードを提示したところ、とてもすばやく対応していただけました。

これだけでなく、私の勘違いにも親切にご指摘を頂いて、WebDNN開発陣には大変感謝しております。

pix2pixのアクティベーション処理

PFN公式実装は、Generator/Discriminatorの最終出力にアクティベーション処理が入っていませんでした。そのため、Gの出力を最後にクリッピングするなどの処理が入っていたのですが、この辺りはTensorFlow/Keras実装に合わせてG側にtanh, D側にsigmoidを入れています。
特にあまり性能差を感じないのですが、出力値が0〜1に必ず収まるのでJS側でクリッピング処理が不要になるという点でこちらのほうが個人的には望ましいです。

個人的に気を付けたこと

今どきのWeb回りにはあまり詳しくないのですが、大きなフレームワークを導入するとやっていることがわかりづらくなるので、できるだけそういったものには頼らないようにしました。
唯一キャンバスで絵を描くためにTeledraw Canvasを使っています。このあたりはpix2pix tensorflow実装のコードを借りてくるという手もあったのですが、むしろそのためのコードがちょっと増えて見通しが悪くなるような気がします。

JavaScript側の実装

今回は8bitエンコーディングされた(圧縮された)モデルを用いますので、webdnn.jsに加えてzlib.jsを読み込んでいます。

モデルの読み込み

WebDNN.loadの引数にモデルファイルのあるパスを指定するだけで読み込みが開始されます。progessCallbackというコールバック関数を与えられる引数があります。これを指定すると、モデルの読み込み処理中に随時関数を呼び出すことができ、プログレスバーの表示などが行なえます。
progressCallbackで呼ばれる関数の引数は2つあり、現在まで読み込んだバイト数とモデル全体のバイト数になります。よって、これらの割合を計算し表示に反映させることでプログレスバーを実現しています。

// プログレスバー用コールバック
function callback_progressbar(current, total) {
    var pct = Math.round(current / total * 100);
    p.p = pct;
    p.update();
}

// モデル読み込みをonload時に実施
window.onload = function() {
    model = WebDNN.load("../out", {progressCallback: callback_progressbar}).then(function (r) {
        var stat = document.querySelector("#status");
        stat.textContent = "Loaded.";
        runner = r; // 推論用のDiscripterRunnerインスタンス
        var v = document.getElementById("prog-bar");
        v.style.visiblility = "hidden";
    });

推論

基本的な流れは以下になります。

  • 線画用canvasからビットマップ情報を取得
    • getImageArrayメソッド(非同期動作)
    • scaleはRGBそれぞれの上限値を与える。浮動小数点数型で0~1の値に変換した結果が得られる
    • orderはデータの並び純を指定する。今回はChainerを使っているため、Channel, Height, Widthの順でデータを扱う必要がある
  • モデルの入力、出力バッファを取得
  • 入力バッファに情報を設定
  • 推論実行 (非同期動作)
    • 結果の取得
    • 出力用canvasに値を設定
    • 出力値も0~1となるため、scaleを与えて0~255の範囲内に変換
    • 出力データの並び順もChainer準拠であることを明示
function start() {
    console.log("start");
    var img = document.querySelector("#input");
    WebDNN.Image.getImageArray(img, { scale: [255, 255, 255], order: WebDNN.Image.Order.CHW}).then(function (img) { // canvasから値の読み込み
        console.log("runner called");
        display_inference();
        startTimer();
        var x = runner.getInputViews()[0]; // 入力バッファ
        var y = runner.getOutputViews()[0]; // 出力バッファ
        // console.log(img)
        x.set(img); // 画像を入力バッファに代入
        runner.run().then(function () { // 推論
            console.log("output");
            var ret = y.toActual(); // 推論結果
            var cvs = document.querySelector("#output");
            var clip = [];
            // console.log(ret);
            WebDNN.Image.setImageArrayToCanvas(ret, 256, 256, cvs, { // 結果をcanvasに反映
                scale: [255, 255, 255],
                bias: [0, 0, 0],
                color: WebDNN.Image.Color.RGB,
                order: WebDNN.Image.Order.CHW
            });
            stopTimer();
            hide_inference();
        });
    });
}

TODO

  • ~JS側のコード解説~
  • 変換手続き
    • 圧縮(8bit encoding)