私が所属している研究室の先輩方が去年作成したJavascriptの機械学習ライブラリTempura
(https://github.com/mil-tokyo/tempura) を試してみる。
私もGMM周りで少しだけコミットしていますが、全体的に把握できているわけではない。
主に線形識別器やクラスタリング用の手法が用意されており。デモとドキュメントが英語でしっかり書かれているので,解読しながら少し遊んでみる。
インストール
まずはインストール
サブモジュールとしてJavaScriptの行列演算ライブラリSushiが使われているため,追加する。
cd your-workspace
git clone git@github.com:mil-tokyo/tempura.git
cd tempura
git submodule init
git submodule update
python compile.py
デモ
tempura/demo/index.html (もしくは http://mil-tokyo.github.io/tempura/demo/index.html )
をブラウザで開くとデモを見ることができる。
ここで行われているデモは
・パーセプトロン
・SGD-SVM
・Nearest-Neighbor
・GMM
の四種類である。これらに引き渡すパラメータをブラウザ上で編集して試すことができ,またそのデモにおいて用いられているソースコードを確認することができる。
使い方
compile.pyを実行した時に生成される
tempura/bin/tempura.js
を読み込むだけで様々な識別器を使えるようなので試してみた。
サンプルデータを用意しパーセプトロンを使って学習,識別を行うコードが下記である。
// トレーニング用の行列と答を用意する
var train = Sushi.Matrix.fromArray([[0,4],[1,4],[-2,1],[-1,2]]);
var labels = Sushi.Matrix.fromArray([[1,1,-1,-1]]);
// 識別器(パーセプトロン)を初期化する
var perceptron = new Tempura.LinearModel.Perceptron({center:true});
// 識別器の学習を行う
perceptron.fit(train, labels.t());
// 学習した識別器の重みを表示する
perceptron.weight.print();
// テストデータを用意する
var sample = Sushi.Matrix.fromArray([[0,3],[-3,2]]);
// テストデータを識別し,推定クラスを表示する
var pred = perceptron.predict(sample);
pred.print();
二次元のトレーニングデータを用意し,パーセプトロンで識別,他のデータを識別した結果を表示するというコードである。書き方も簡単で実に直感的だ。
パーセプトロンの主な使い方や,他の手法に関するドキュメントは (http://mil-tokyo.github.io/tempura/docs/index.html) に掲載されている。
##おまけ
javascriptでかかれても実行の仕方わかんないよ,という人向けにhtmlで用意してみたので,tempura直下に置いて試してみてね。
<!DOCTYPE html>
<html>
<head>
<title>test</title>
<meta charset="UTF-8">
<script src="./sushi/src/sushi.js"></script>
<script src="./bin/tempura.js"></script>
</head>
<body>
<h1>Test</h1>
<script>
// トレーニング用の行列と答を用意する
var train = Sushi.Matrix.fromArray([[0,4],[1,4],[-2,1],[-1,2]]);
var labels = Sushi.Matrix.fromArray([[1,1,-1,-1]]);
// 識別器(パーセプトロン)を初期化する
var perceptron = new Tempura.LinearModel.Perceptron({center:true});
// 識別器の学習を行う
perceptron.fit(train, labels.t());
// 学習した識別器の重みを表示する
perceptron.weight.print();
// テストデータを用意する
var sample = Sushi.Matrix.fromArray([[0,3],[-3,2]]);
// テストデータを識別し,推定クラスを表示する
var pred = perceptron.predict(sample);
pred.print();
</script>
</body>
</html>
リンク
作成したJava Scriptライブラリ一覧
Sushi(行列): https://github.com/mil-tokyo/sushi
Tempura(機械学習): https://github.com/mil-tokyo/tempura
Sukiyaki(ディープラーニング): https://github.com/mil-tokyo/sukiyaki