LoginSignup
7
5

More than 5 years have passed since last update.

kerasで学習した結果をブラウザ上でインタラクティブなプロットとして表示する

Last updated at Posted at 2018-07-26

はじめに

一般に機械学習で回帰(重回帰)すると、学習した結果として多変数関数が得られる。
つまり説明変数$x$として、目的変数$y$を予測する関数$f(x_0, ...,x_n)$が得られる。

この関数がどんな形をしているのか知りたいと思った時、2次元上のグラフだとプロットしても一つの説明変数$x_i$に対してどのように$f$が変わっていくのかしかわからないが、各変数が変わった時に予測値がどのように変わるのか直感的に理解したい。

そこで、それぞれの説明変数をスライダーで動かして、$f$のプロットをインタラクティブに変えることができるものを作ってみた。

image.png

大まかな流れは以下の通り。

  • Kerasを使ってニューラルネットワークを実装し、回帰を実行。学習結果を保存。
  • Tensorflow.jsを使ってKerasの学習結果をJSからアクセスできるように変換。
  • JSからtensorflow.jsを使って学習結果を読み込む。
  • 学習結果を使って予測値を計算する。
  • chart.jsを使ってグラフを描画。

Kerasで回帰分析をする

ボストンの住宅価格のデータを例に回帰分析を行ってみた。
kerasのコードはこちらを参考にした。

reg.py
import numpy as np
from keras.datasets import boston_housing
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import BatchNormalization
from keras.optimizers import Adam

(x_train, y_train), (x_test, y_test) = boston_housing.load_data(test_split=0.0)

def scale_input(data):
    from sklearn.preprocessing import StandardScaler
    scaler = StandardScaler()
    return scaler.fit_transform(data)

x_train = scale_input(x_train)

def create_model():
    model = Sequential()
    model.add
    model.add(Dense(13, input_dim=13, kernel_initializer='normal', activation='tanh'))
    model.add(Dense(1, kernel_initializer='normal'))
    model.compile(loss='mean_squared_error', optimizer='adam')
    return model

m = create_model()
h = m.fit(x_train, y_train, batch_size=50, epochs=1000, verbose=1, validation_split=0.2)

m.save('my_model.h5')

詳細な説明は省くが、やっていることは以下の通りである。

  • Bostonの住宅価格のデータセットをロード boston_housing.load_data
  • 説明変数を平均0、分散1に正規化 scale_input(x_train)
  • ニューラルネットワークのモデルを構築 create_model()
  • 学習の実行 m.fit(x_train, y_train, ...)
  • 学習結果の保存 m.save('my_model.h5')

numpy,tensorflow,keras,scikit-learn をインストールし、python reg.pyで実行できる。

学習結果をTensorflow.jsを使ってJSから読めるようにする

tensorflow.jsは、Javascriptで学習を行ったり、すでに学習した結果をJSから使うためのライブラリ。
tensorflow.js公式ページ

ここではtensorflow.jsで学習結果を使うために、Kerasで使ったモデルをtensorflow.jsに変換する。
kerasのモデルをtensorflow.jsに変換するためのチュートリアルを参考に、以下の手順で変換を実行。

pip install tensorflowjs
tensorflowjs_converter --input_format keras my_model.h5 converted

すると"my_model.h5"を"converted/"というディレクトリに変換してくれる。中身をみてみると、JSONとバイナリファイルができている。

tree converted
converted/
├── group1-shard1of1
├── group2-shard1of1
└── model.json

0 directories, 3 files

tensorflow.jsを使って変換後のモデルを読む

この結果をウェブページのJSからロードするには、tensorflow.jsをCDNから読む。
以下のスクリプトタグでライブラリをロードし、以後tfという変数にアクセスできる。

<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.12.0"></script>

tf.loadModel(path)で先ほど変換したモデルを読み込むことができる。

async function start() {
  const model = await tf.loadModel('converted/model.json');
  // ... 読み込んだモデルを使って何かする ...
}
start();

注意点としてtf.loadModelは非同期に実行される(Promiseが返される)ので、読み込み完了まで待ちたい場合はawaitする必要がある。

また別の注意点として、変換後のモデルを読むためにはconvertedにアクセスできないといけないので、適当なウェブサーバーを起動する必要がある。

npm install -g http-server     # ここではnpmを使ってhttp-serverをインストールしているがなにを使っても良い
http-server

モデルを使って予測値を計算

変換後のモデル$f(x_0, x_1, ..., x_i)$を使って予測値を計算する。

まずは特定の軸(ここでは$x_0$とする)を横軸にとって、縦軸に$f$をプロットすることを考える。
説明変数$x_0$のみ[-1,1]の範囲で0.1ずつ変え、他の説明変数は0とする。
まず入力点を以下のように作る。

  let x_base = (new Array(13)).fill(0);  // [0,0,...0] length=13の配列
  let xs = [];
  for(let i=-10; i<11; i++) {            // x0を[-1,1]の範囲で0.1ずつ変える
    x_base[0] = i*0.1;
    xs.push(x_base.concat()); // copy
  }
  const xt = tf.tensor(xs);   // 
  xt.print();

これを実行するとxsが以下のように計算される。

xs = [ [-1.0, 0, 0, ...., 0],
       [-0.9, 0, 0, ...., 0],
       [-0.8, 0, 0, ...., 0],
       ...
       [0.9, 0, 0, ....., 0],
       [1.0, 0, 0, ....., 0],
     ]

さらにこれをtensorflow.jsのモデルに対して計算するためには、tensor型に変換する必要がある。
数値の配列に対して tf.tensor(xs) を呼ぶとtensor型になる。(tensorに変換した後のxtのshapeは13x21となる。)

このxtに対してデルの予測値を計算するためにはmodel.predictを呼ぶ。

const predicted = model.predict(xt)
predicted.print();
//  [[24.6219616],
//   [25.430891 ],
//   [26.2986507],
//   ...,
//   [40.3675957]]

すると予測値の計算結果として21x1のtensorが得られる。このようにしてKerasで学習した結果を使って予測値を計算できる。

またpredictedから個別の値を取得するためにはget関数を使う

let y = predicted.get(3,0);    // 例えばidx=3の結果を取得する場合

chart.jsを使ってグリグリ動くグラフを作る

ここまでくればあとは可視化の部分を作るだけである。
chart.jsでぐりぐり動くグラフを作る とほぼ同様の手順で作るので詳細は省略するが、大まかな手順は以下の通り

  • 各$x_i$に対応するスライダーを作る
    • <input type="range" min="-1" max="1" value="0" step="0.1" class="slider" id="x0"> の用に作った
  • スライダーが変更された時にグラフを描画するようにイベントを設定する。
var sliders = document.getElementById("sliders");
sliders.addEventListener("input", function() {
  start(Number(document.getElementById("myRange").value));  // start()がグラフを再描画する処理
}, false);
  • グラフ描画関数startを実装する
    • sliderの値から各$x$の値を取得
    • その$x$に対して予測値$f(x)$を計算
    • chart.jsでグラフを描画

ソースコード

この記事では大まかな方針のみ書いたので、詳細な実装についてはこちらのgistを参照してほしい。

7
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
7
5