LoginSignup
2
4

More than 3 years have passed since last update.

Tensorflow.js + Chart.js で学習の様子を可視化する

Last updated at Posted at 2019-07-16

はじめに

TensorFlowなどの機械学習用ライブラリを使って実際に画像認識モデルとか実装してみるものの、
精度がいくらですよとか、ラベルがいくつですよみたいな、数値の結果しか基本的には得られない。
可視化の機能もあるとはいっても、精度の遷移をグラフにしたものとかで、
とくに全くの初学者にはいまいちピンとこないんじゃないかと思う(思わない?)。

結局、機械学習(とかディープラーニング)ってなにをやってるの?
学習ってなに?
っていうと、与えたデータに対してうまいぐあいにフィッティングする曲線を自動的に獲得してるってことになる。
(データの分類だとちょっと違うけど考え方によってはまあね)

なので、機械学習がどのようなものなのかよりイメージしやすいように、
学習の様子、曲線がデータにフィッティングしていく様子を可視化してみる。
機械学習にはTensorflow.jsを使う。
データの描画にはChart.jsを使う。

学習可視化の方針

  • 学習用データとそれに対するモデル推論結果の二種類のデータ群をChart.jsで描画する。
  • Tensorflow.jsで学習用データに対するモデルをつくる。
  • モデルの学習処理中に(1 epochごとに)描画処理を呼ぶ。
  • 描画処理にてモデルによる推論結果を取得し、描画するデータを更新する。

実装

全体

https://codepen.io/dldemo/full/maJXGb
もしくは

See the Pen Tensorflow.js + Chart.js by demo (@dldemo) on CodePen.

モデルの学習処理

const EPOCHS = 100;
const BATCH_SIZE = 10;
const trainModel = async () => {
  const f = await nnModel.fit(dh.convertTensorDataX(), dh.convertTensorDataY(), {
    epochs: EPOCHS,
    batchSize: BATCH_SIZE,
    callbacks: {
      onEpochEnd: updateDrawing
    }
  });
};

上記コードの callbacks部分、onEpochEnd: updateDrwingが 1 epoch ごとにupdateDrawingの処理を呼ぶための設定。
updateDrawingに描画処理を記述すればいい。詳細は後述する。
nnModelは学習モデル。
ちなみにdh.hogehogeは描画する学習用データのTensor型を返している。

Tensorflow.jsでモデルをつくる

モデルの作成

最適化手法の設定

const optimConfig = {
  optimizer: 'adam',
  loss: 'meanSquaredError'
}
レイヤーの設定
const sequentialDenseModelConfig = {
  layers: [{
    units: 1,
    inputShape: [1]
  },{
    units: 20,
    activation: 'relu'
  },{
    units: 20,
    activation: 'relu'
  },{
    units: 20,
    activation: 'relu'
  },{
    units: 20,
    activation: 'relu'
  },{
    units: 1,
    activation: 'linear'
  }]
}
モデル作成

上記の最適化手法およびレイヤーの設定を用いてモデルを作成、コンパイルする。
ここではtf.sequentialを用いた単純なモデルをつくっている。

const createSequentialDenseModelFromConfig = (compile=true) => {
  let model = tf.sequential();
  for (key in sequentialDenseModelConfig['layers']) {
    model.add(tf.layers.dense(sequentialDenseModelConfig['layers'][key]));
  }
  if (compile) { 
    model.compile(optimConfig);
  }
  return model;
};

モデルの推論

つくったモデルに対して実際に入力を与えて結果を得ることを推論と呼んだりする。
モデルによる予測処理。

tidyPredict = (model, x) => tf.tidy(() => model.predict(x));

tf.tidyを使うことでうまくメモリ管理をしてくれる。
xに学習用データを与えることでデータに対する予測結果を得られる。描画処理でつかう。

Chart.jsでデータを描画する

もろもろの設定


  const color = Chart.helpers.color;
  const chartColors = {
    red: '#FF0000',
    blue: '#0000FF'
  };
  const scatterData = {
    datasets:[{
      label: 'train dots',
      borderColor: chartColors.red,
      backgroundColor: color(chartColors.red).alpha(0.2).rgbString(),
      pointRadius: 10,
      data: dh.originalData,
      type: 'scatter'
    },{
      label: 'predict line',
      borderColor: chartColors.blue,
      backgroundColor: color(chartColors.blue).alpha(0.2).rgbString(),
      data: dh.packPredictedDatasets(nnModel),
      type: 'scatter'
    }]
  };

datasetsに描画するデータを設定する。
ここでは学習用データとモデルの推論結果の二種類を設定する。
一つ目が学習用データ。二つ目がモデルの推論結果を使ったデータ。
お好みで配色やグラフの種類を設定できる。
type: 'scatter'で散布図となる。
dataの項目に描画するデータを設定する。
データの型は以下のように x 座標と y 座標の二つのパラメータで定めたデータの配列でよい。

class ChartDataModel {
  constructor(x, y) {
    this.x = x;
    this.y = y;
  }
}

dh.packPredictedDatasets(nnModel)で、学習用データ x とその推論結果 y をzip化して取得している。
※ Chart.js とTensorflow.jsで扱う型が異なるため、型変換してまとめている。

const zip = (array1, array2) => array1.map((_, index) => [array1[index], array2[index]]);
packPredictedDatasets = (model) => {
    const tensorData = this.convertTensorDataX();
    return zip(
      Array.from(tensorData.dataSync()),
      Array.from(this.tidyPredict(model, tensorData).dataSync())
    ).map(([x, y]) => new ChartDataModel(x, y));
  }

描画

以下のようにお決まりの構文を記述すれば上記で設定した内容に従い、canvasタグに対してグラフを描画してくれる。

  const ctx = document.getElementById('canvas').getContext('2d');
  window.chart = new Chart(ctx, {
    type: 'scatter',
    data: scatterData,
    option:{
      title: {
        display: true,
        text: 'Chart.js Scatter Chart'
      },
      scales: {
        xAxes: {
          ticks: {
            min: DOMAIN_MIN,
            max: DOMAIN_MAX,
          }
        }
      }
    }
  });

DOMAIN_hogeは描画する x 軸の変域。

描画データの更新

前述したモデルの学習処理にて 1 epoch ごとに呼ぶupdateDrawing処理では、
Chart.js で描画するデータを上書き更新処理をおこなうことで更新内容が反映される。

const updateDrawing = async (epoch, log) => {
  window.chart.data.datasets[1].data = dh.packPredictedDatasets(nnModel);
  window.chart.update();
};

おわりに改善メモ

  • 最適化手法を選択化
  • 分類の可視化
  • データのアップロード
  • 画面上でモデル作成

などなど。。

2
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
4