この記事について
以前作成した、pix2pixによる猫画像の線画着色モデルを、ブラウザ側で動作させるようにしました。そこに至るまでの軌跡を解説します。
これに伴い、今までサーバーサイドで動作させていたpix2pixは止めました。
WebDNN
既存のフレームワークで学習したモデルをWebブラウザから利用できるようにしたものとして、deeplearnjsやWebDNNがあります。
当初はTensorFlow実装のpix2pixモデルをそのまま使うことを考えていたため、deeplearnjsを試すことを考えていました。しかしJSで計算グラフの定義をする必要がある(pbtxtを読んでグラフを勝手に作ってくれるとかはしない)ので、いったん採用を保留しました。
次にWebDNNの利用を検討しました。WebDNNはMakeGirls.moeで既に利用されており実績があります。しかしながら、TensorFlowの対応は未だ実験的ということだったので、kerasやchainerを使うことにしました。
pix2pix keras実装
kerasで記述されたpix2pix実装は複数ありますが、直接利用するには問題がありました。
-
https://github.com/tdeboissiere/DeepLearningImplementations/tree/master/pix2pix
- keras 1.x用
- ちょっとした手直しだけではkeras2で動かすのは困難
- 訓練データの与え方やデフォルトのハイパーパラメータが異なる
- 学習率初期値が論文より大きめ
- 手持ちのデータでうまく訓練できなかった
-
https://github.com/costapt/vess2ret
- keras 1/2両対応の実装
- 入出力サイズが512x512と大きい
- 画像のshapeを無理やりthanoベースに合わせている
一応vess2retベースのモデルでWebDNNの動作も出来はしたのですが、一部うまくいかない部分があり(ちょっと記憶がおぼろげ)chainerを使うことにしました。
pix2pix chainer実装
chainerで実装されたpix2pixを探したところ、PFN公式実装が存在していました。
-
https://github.com/pfnet-research/chainer-pix2pix
- facadeデータセットのみ対応
- 入力12チャンネル(12クラス)
- facadeデータセットのみ対応
この実装をベースに、入力3チャンネルのカラー画像が受けられ、以前作成した訓練画像が直接読み込めるよう手を入れたバージョンを作成しました。
スクリプト
自分のリポジトリには、訓練データや訓練済みモデルなどをダウンロードするスクリプトを用意してあります。
単純にブラウザで試すだけであれば、WebDNNのwebassembly backendモデルを取得、展開するget-webdnn-model.shを実行するだけで良いです。
データセットからモデルの訓練をはじめから行うには、以下のスクリプトを実行してください。
- get-trained-data.sh
- https://github.com/knok/pix2pix-tensorflow/releases/tag/v1.0 からデータセットを取得、展開します
- train.sh
- データセットに基づき訓練を実行します
- dump.sh
- train.shで生成されたchainerモデルをWebDNNに変換します
動作
webassemblyはローカルから直接読んで実行することを禁止しているブラウザが大半のようなので、http経由でhtml, outディレクトリが見える場所に置いてhtml/index.htmlを読み込んでください。
ページを読み込むとモデル読み込みが自動で始まります。読み込みが完了したら、左側のキャンバスに線画を書いて"convert"を押すことでpix2pixモデルの計算が始まります。i5-3750を積んだマシンとDebian stretch, Chrome 61の組み合わせで計算に25秒程度かかります。Nextbit Robin(Snapdragon 808)で106秒程度、OnePlus3T(Snapdragon 821)で47秒程度でした。
"preset"ボタンを押すとこちらで用意しておいたプリセット線画が設定されます。訓練画像と同じものなので、これを使って変換をさせるとかなりきれいな結果が出ます。
ここに至るまでの道のり
いくつかの勘違い
WebDNNの各種APIは非同期動作するものが多くPromiseが使われているので、その点を踏まえる必要がありました。.then()が使えるのでそちらを活用しましょう。
また、順方向計算では不要な処理(batch normalizationなど)は取り除く必要があります。最初これに気づかず、Issueを立ててしまいました。mnistサンプルやチュートリアルには特に記載がないのですが、ResNetのサンプルにはchainer.using_config('train', False)を使った例があります。
Deconvolution2Dのバグ
pix2pixのU-Net decoder側にはDeconvolution2Dが使われているのですが、当初このコードは適切に動いていませんでした。AutoEncoderを使った再現コードを提示したところ、とてもすばやく対応していただけました。
これだけでなく、私の勘違いにも親切にご指摘を頂いて、WebDNN開発陣には大変感謝しております。
pix2pixのアクティベーション処理
PFN公式実装は、Generator/Discriminatorの最終出力にアクティベーション処理が入っていませんでした。そのため、Gの出力を最後にクリッピングするなどの処理が入っていたのですが、この辺りはTensorFlow/Keras実装に合わせてG側にtanh, D側にsigmoidを入れています。
特にあまり性能差を感じないのですが、出力値が0〜1に必ず収まるのでJS側でクリッピング処理が不要になるという点でこちらのほうが個人的には望ましいです。
個人的に気を付けたこと
今どきのWeb回りにはあまり詳しくないのですが、大きなフレームワークを導入するとやっていることがわかりづらくなるので、できるだけそういったものには頼らないようにしました。
唯一キャンバスで絵を描くためにTeledraw Canvasを使っています。このあたりはpix2pix tensorflow実装のコードを借りてくるという手もあったのですが、むしろそのためのコードがちょっと増えて見通しが悪くなるような気がします。
JavaScript側の実装
今回は8bitエンコーディングされた(圧縮された)モデルを用いますので、webdnn.jsに加えてzlib.jsを読み込んでいます。
モデルの読み込み
WebDNN.loadの引数にモデルファイルのあるパスを指定するだけで読み込みが開始されます。progessCallbackというコールバック関数を与えられる引数があります。これを指定すると、モデルの読み込み処理中に随時関数を呼び出すことができ、プログレスバーの表示などが行なえます。
progressCallbackで呼ばれる関数の引数は2つあり、現在まで読み込んだバイト数とモデル全体のバイト数になります。よって、これらの割合を計算し表示に反映させることでプログレスバーを実現しています。
// プログレスバー用コールバック
function callback_progressbar(current, total) {
var pct = Math.round(current / total * 100);
p.p = pct;
p.update();
}
// モデル読み込みをonload時に実施
window.onload = function() {
model = WebDNN.load("../out", {progressCallback: callback_progressbar}).then(function (r) {
var stat = document.querySelector("#status");
stat.textContent = "Loaded.";
runner = r; // 推論用のDiscripterRunnerインスタンス
var v = document.getElementById("prog-bar");
v.style.visiblility = "hidden";
});
推論
基本的な流れは以下になります。
- 線画用canvasからビットマップ情報を取得
- getImageArrayメソッド(非同期動作)
- scaleはRGBそれぞれの上限値を与える。浮動小数点数型で0~1の値に変換した結果が得られる
- orderはデータの並び純を指定する。今回はChainerを使っているため、Channel, Height, Widthの順でデータを扱う必要がある
- モデルの入力、出力バッファを取得
- 入力バッファに情報を設定
- 推論実行 (非同期動作)
- 結果の取得
- 出力用canvasに値を設定
- 出力値も0~1となるため、scaleを与えて0~255の範囲内に変換
- 出力データの並び順もChainer準拠であることを明示
function start() {
console.log("start");
var img = document.querySelector("#input");
WebDNN.Image.getImageArray(img, { scale: [255, 255, 255], order: WebDNN.Image.Order.CHW}).then(function (img) { // canvasから値の読み込み
console.log("runner called");
display_inference();
startTimer();
var x = runner.getInputViews()[0]; // 入力バッファ
var y = runner.getOutputViews()[0]; // 出力バッファ
// console.log(img)
x.set(img); // 画像を入力バッファに代入
runner.run().then(function () { // 推論
console.log("output");
var ret = y.toActual(); // 推論結果
var cvs = document.querySelector("#output");
var clip = [];
// console.log(ret);
WebDNN.Image.setImageArrayToCanvas(ret, 256, 256, cvs, { // 結果をcanvasに反映
scale: [255, 255, 255],
bias: [0, 0, 0],
color: WebDNN.Image.Color.RGB,
order: WebDNN.Image.Order.CHW
});
stopTimer();
hide_inference();
});
});
}
TODO
-
- 変換手続き
- 圧縮(8bit encoding)