Edited at

HTMLファイル一つで機械学習(ニューラルネットワーク)してみる その2

1年以上前に全然門外漢なのに機械学習をさわってみたくなり以下の記事を書きました。

HTMLファイル一つで機械学習してみる

https://qiita.com/basictomonokai/items/329ab7c92a5ae49e77f8

当時と変わらず、全然門外漢でHTML,JS,CSSもドシロウトです。

ドシロウトなのにQiitaに書いていいのか悩みましたが以前だれかに「投稿した方が間違いも指摘してもらえるのでした方が良い」とアドバイスされたので書きます。

以前の記事ではBrain.jsについて書きましたがどうもその後継版が出ていたみたいです。

BrainJS/brain.js

https://github.com/BrainJS/brain.js

冒頭の以下の説明文を見るとBrain.jsの後継版であることが分かります。


This is a continuation of the harthur/brain repository (which is not maintained anymore).


皆さんには簡単すぎるのかBrainについての記事がほとんどないのと前回同様、小さいHTMLファイル一つで機械学習のテストができたので投稿します。


1.Brain後継版についてわかっていること

あんまりわかってませんがこんな感じだと思います。1~6までは以前版と同じです。

7.については以前版では出来たのかもしれませんがやり方がわかりませんでした。


  1. jsで出来ている

  2. 本当はNode.jsで使うべきだがCDN指定でブラウザ上でも動く

  3. クライアント側でローカル実行可能

  4. 入力、出力のパラメタ数はいくつでもいいらしい

  5. 入力、出力の各パラメタの取りうる値の範囲は0~1の間のみ

  6. 0~1なので各パラメータの最大値で割る、割り戻すなどの前、後処理が必要

  7. 学習済データの読込、保存ができる


2.画面

作成した画面はこんな感じになります。



  • ボタンの説明


    • 学習・・・Brainの学習開始

    • パラメタ設定・・・Brainのnetwork用パラメータを反映

    • ダウンロード・・・学習済データをJSONで保存

    • ファイル選択・・・学習済データ(JSON)の読込

    • テスト・・・学習済データを利用してテスト実行

    • クリア(3つ)・・・各種テキストエリアのクリア




  • テキストエリアの説明


    • network用パラメータ・・・ノードや隠れ層の数等を指定

    • テスト用データ・・・テスト用のinputデータを指定

    • 学習用のinput&outputデータを指定




  • 結果表示エリア


    • 新しい順に処理結果を表示




3.コード

さっそく汚いですがコードです。たったこれだけでパラメタ設定→機械学習→テスト→学習結果保存→学習結果読込までできます。とても簡単です。

処理の結果は全て結果表示エリアに表示されます。


サンプルコード

<!DOCTYPE html>

<html>
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width">
<style>
#intest,#gaku,#inparm {
margin: 10px;
width: 80%;
}

#btnarea,#clr {
margin: 10px;
}

#kekka {
background: #ecf0f1;
}

</style>
<title>Brain</title>
<script src="https://cdn.rawgit.com/BrainJS/brain.js/master/browser.js"></script>
</head>
<body>
<div id="btnarea">
<button id="jikko">学習</button>
<button id="parm">パラメタ設定</button>
<button id="download">ダウンロード</button>
<input type="file" id="file" name="file" accept=".json">
<button id="test">テスト</button>
</div>
<div id="clr">
<button id="jikkoclr">学習クリア</button>
<button id="testclr">テストクリア</button>
<button id="parmclr">パラメタクリア</button>
</div>
<div>network用パラメータ</div>
<textarea id="inparm" rows="3" placeholder="未指定時はデフォルト"></textarea>
<div>テスト用データ</div>
<textarea id="intest" rows="4" placeholder="テスト用データ"></textarea>
<div>学習用データ</div>
<textarea id="gaku" rows="5" placeholder="学習用データ"></textarea>
<div>=== LOG(新しい順) ===</div>
<div id="kekka"></div>

<script>
var network;
var json,result,result2;
var kekka=document.getElementById('kekka');

alert('最初にbrain networkのパラメターを設定してください');

// ローカルストレージ読込
var stokey1 = 'gaku20180920';
var str = window.localStorage.getItem(stokey1);
if (str == null) {
console.log('ローカルストレージに学習情報なし');
kekka.innerHTML = 'ローカルストレージに学習情報なし'+ '<br>' + kekka.innerHTML;
} else {
if (str == '') {
console.log('ローカルストレージに学習情報が空');
kekka.innerHTML = 'ローカルストレージに学習情報が空'+ '<br>' + kekka.innerHTML;
} else {
console.log('ローカルストレージより学習情報取得');
kekka.innerHTML = 'ローカルストレージより学習情報取得' + '<br>' + kekka.innerHTML;
document.getElementById('gaku').value=str;
};
};

var stokey2 = 'test20180920';
var str2 = window.localStorage.getItem(stokey2);
if (str2 == null) {
console.log('ローカルストレージにテスト情報なし');
kekka.innerHTML = 'ローカルストレージにテスト情報なし' + '<br>' + kekka.innerHTML;
} else {
if (str2 == '') {
console.log('ローカルストレージにテスト情報が空');
kekka.innerHTML = 'ローカルストレージにテスト情報が空'+ '<br>' + kekka.innerHTML;
} else {
console.log('ローカルストレージよりテスト情報取得');
kekka.innerHTML = 'ローカルストレージよりテスト情報取得' + '<br>' + kekka.innerHTML;
document.getElementById('intest').value=str2;
};
};

var stokey3 = 'parm20180920';
var str3 = window.localStorage.getItem(stokey3);
if (str3 == null) {
console.log('ローカルストレージにパラメタ情報なし');
kekka.innerHTML = 'ローカルストレージにパラメタ情報なし' + '<br>' + kekka.innerHTML;
} else {
if (str3 == '') {
console.log('ローカルストレージにパラメタが空');
kekka.innerHTML = 'ローカルストレージにパラメタ情報が空'+ '<br>' + kekka.innerHTML;
} else {
console.log('ローカルストレージよりテスト情報取得');
kekka.innerHTML = 'ローカルストレージよりパラメタ情報取得' + '<br>' + kekka.innerHTML;
document.getElementById('inparm').value=str3;
};
};

// 学習ボタンの処理
var btn2 = document.getElementById('jikko');
btn2.addEventListener('click', function() {

// 学習用データのオブジェクト化
var gakudata = new Object();

var wkgakudata = document.getElementById('gaku').value;
window.localStorage.setItem(stokey1, wkgakudata);

gakudata = (new Function('return ' + wkgakudata))();

var date1 = new Date();
console.log('◆学習開始');
console.log(date1);
kekka.innerHTML = '◆学習開始' + '<br>' + kekka.innerHTML;
kekka.innerHTML = date1 + '<br>' + kekka.innerHTML;
network.train(gakudata);

var date2 = new Date();
console.log('◆学習終了');
console.log(date2);
kekka.innerHTML = '◆学習終了' + '<br>' + kekka.innerHTML;
kekka.innerHTML = date2 + '<br>' + kekka.innerHTML;

console.log('全学習時間:'+(date2-date1));
kekka.innerHTML = '全学習時間:'+ (date2-date1) + '<br>' + kekka.innerHTML;
json = network.toJSON();

});

// パラメータ設定の処理
var btn6 = document.getElementById('parm');
btn6.addEventListener('click', function() {

var wkparmdata = document.getElementById('inparm').value;
if (wkparmdata == ''){
// Brain
network = new brain.NeuralNetwork();
kekka.innerHTML = 'デフォルトパラメータ<br>' + kekka.innerHTML;
} else {

// パラメータデータのオブジェクト化
var parmdata = new Object();

window.localStorage.setItem(stokey3, wkparmdata);
parmdata = (new Function('return ' + wkparmdata))();
console.log(parmdata);
// Brain
network = new brain.NeuralNetwork(parmdata);
kekka.innerHTML = JSON.stringify(parmdata) + '<br>' + kekka.innerHTML;

};

console.log('◆パラメータ設定済');
kekka.innerHTML = '◆◆パラメータ設定済' + '<br>' + kekka.innerHTML;

});

// テストボタンの処理
var btn3 = document.getElementById('test');
btn3.addEventListener('click', function() {

// テスト用データのオブジェクト化
var testdata = new Object();

var wktestdata = document.getElementById('intest').value;
window.localStorage.setItem(stokey2, wktestdata);
testdata = (new Function('return ' + wktestdata))();

var date1 = new Date();
console.log('◆テスト開始');
console.log(date1);
kekka.innerHTML = '◆テスト開始' + '<br>' + kekka.innerHTML;
kekka.innerHTML = date1 + '<br>' + kekka.innerHTML;

// テスト実行1 network.run
result2 = network.run(testdata);
console.log('run結果');
console.log(result2);
kekka.innerHTML = 'run結果' + '<br>' + kekka.innerHTML;
kekka.innerHTML = JSON.stringify(result2) + '<br>' + kekka.innerHTML;
// テスト実行2 brain.likely
result = brain.likely(testdata, network);
console.log('likely結果');
console.log(result);
kekka.innerHTML = 'likely結果' + '<br>' + kekka.innerHTML;
kekka.innerHTML = result + '<br>' + kekka.innerHTML;
var date2 = new Date();
console.log('◆テスト終了');
console.log(date2);
kekka.innerHTML = '◆テスト終了' + '<br>' + kekka.innerHTML;
kekka.innerHTML = date2 + '<br>' + kekka.innerHTML;

console.log('テスト時間:'+(date2-date1));
kekka.innerHTML = 'テスト時間:'+ (date2-date1) + '<br>' + kekka.innerHTML;

});

// 学習用データクリアの処理
var btn4 = document.getElementById('jikkoclr');
btn4.addEventListener('click', function() {

document.getElementById('gaku').value = '';

});

// テスト用データクリアの処理
var btn5 = document.getElementById('testclr');
btn5.addEventListener('click', function() {

document.getElementById('intest').value = '';

});

// パラメタ用データクリアの処理
var btn7 = document.getElementById('parmclr');
btn7.addEventListener('click', function() {

document.getElementById('inparm').value = '';

});

// ダウンロードボタンの処理
var btn = document.getElementById('download');
btn.addEventListener('click', function() {
console.log('◆学習結果保存開始');
kekka.innerHTML = '◆学習結果保存開始' + '<br>' + kekka.innerHTML;

// 学習結果JSONを取得
var content = JSON.stringify(json);

// Blob形式に変換する
let blob = new Blob([content]);

// Blobデータに対するURLを発行する
let blobURL = window.URL.createObjectURL(blob);

// URLをaタグに設定する
let a = document.createElement('a');
a.href = blobURL;

// download属性でダウンロード時のファイル名を指定
if (typeof aaresult === "undefined") {
a.download="noname.json";
} else {
a.download = aaresult.name;
};

document.body.appendChild(a);

// CLickしてダウンロード
a.click();

// 終わったら不要なので要素を削除
a.parentNode.removeChild(a);

 console.log('◆学習結果保存完了');
kekka.innerHTML = '◆学習結果保存完了' + '<br>' + kekka.innerHTML;

});

// ファイル選択ボタンの処理

var el2 = document.getElementById("file");

el2.addEventListener( 'change', function(e) {
console.log('◆学習結果読込開始');
kekka.innerHTML = '◆学習結果読込開始' + '<br>' + kekka.innerHTML;
var result = e.target.files[0];
aaresult=e.target.files[0];

//FileReaderのインスタンスを作成する
var reader = new FileReader();

//読み込んだファイルの中身を取得する
reader.readAsText( result );

//ファイルの中身を取得後に処理を行う
reader.addEventListener( 'load', function() {
  console.log('◆学習結果読込完了');
kekka.innerHTML = '◆学習結果読込完了' + '<br>' + kekka.innerHTML;

var afttrain = JSON.parse(reader.result)
console.log(afttrain);
json = afttrain;
network.fromJSON(afttrain);
kekka.innerHTML = aaresult.name + ' 取得完了' + '<br>' + kekka.innerHTML;

})

});

</script>

</body>
</html>



4.network用パラメータ

network用パラメータはおそらくニューラルネットワークのチューニングだと思いますが正直よくわかっていません。

わかっていないのでBrainのページに載っていた通り、activation functionはデフォルトのままでノードは3、隠れ層4で設定しました。

{

activation: 'sigmoid', // activation function
hiddenLayers: [3,4]
}


5.学習用データ

自信がありませんがテストするためには学習用データが必要なので多分この手のデータで有名なkaggleの以下のデータを使いました。

Death in the United States(2015)

https://www.kaggle.com/cdc/mortality/version/2#2015_data.csv

悪趣味と言われそうですが一番加工しやすそうなので選らんだだけです。

データ総件数271万件、その内自殺者のデータのみを抽出して約44,000件を元データにして正規化、性別、年齢、既婚/未婚をinput、没曜日をoutputにした学習データを作成しました。

学習データ約44,000件の一部が以下です。

※性別の項目名はmaleにしたかったのですがmailにTYPOしてます・・・

[

{input: {mail:1,age:0.133,marry:0}, output:{mon:1}},
{input: {mail:1,age:0.142,marry:0}, output:{thu:1}},
{input: {mail:1,age:0.142,marry:0}, output:{tue:1}},
{input: {mail:1,age:0.142,marry:0}, output:{mon:1}},
]


6.テスト用データ

テストするためにはテスト用データも必要なので以下のデータを使いました。学習用データのinput側のみのデータです。

テストデータは「男性、55歳、既婚」の意味になります。

{mail:1,age:0.416,marry:1}


7.実行結果

実行結果は以下です。


  • 全学習時間:2,604,404ミリ秒(約44,000件、約43分)

  • テスト時間:88ミリ秒

  • run結果(JSONで返却):

{"tue":0.1497674286365509,"mon":0.13587155938148499,"sat":0.13362683355808258,"fri":0.14742684364318848,"sun":0.12110204994678497,"wed":0.14535333216190338,"thu":0.15461139380931854}


  • likely結果:thu(木曜日)

画面表示は以下でした。


8.おわりに

そもそもこのやり方で良いのかよく分かってませんがBrain.js自体は実行することが出来ています。

間違い等あればドシロウトですが教えていただけると幸いです。

以上です。