JavaScript

JavaScript でニューラルネット(XOR回路を作ってみる)

ちょっと前に書いたコードのメモ:

XOR(排他的論理和)を、ニューラルネットに学習させる例

<XOR>
True と True で False  (1 XOR 1 -> 0)
True と False で True   (1 XOR 0 -> 1)
False と True で True   (0 XOR 1 -> 1)
False と False で False  (0 XOR 0 -> 0)

結果
3/0(0),0(1),_0.009335834287899665(0),
2/0(0),1(1),_0.9981023530577846(0),
1/1(0),1(1),_0.010547904797307581(0),
0/1(0),0(1),_0.9859294247263731(0),

<結果の見方> ・・・ 各行は「パターン番号/入力0(0),入力1(1),_出力(0),」の形式(ログは新しい行が上に追加されるため、パターン番号は逆順に並ぶ)。出力が期待値 (0 または 1) に十分近いので、XOR回路が学習できている
0,0 →  0.00933
0,1 →  0.99810
1,1 →  0.01054
1,0 →  0.98592

myjs.html
<!doctype html>
<html lang="ja">
<head>
<meta charset="utf-8">
<title>ニューラルネットワークをJavaScriptで</title>
<link rel="stylesheet" href="myjs.css">
<script src="./myjs.js"></script>
<script>
// Once the page has loaded, show the source text of program1 in #program1
// and then run it (program1 writes its log into #result1).
window.addEventListener('load', function () {
    document.getElementById('program1').textContent = String(program1);
    program1();
});
</script>
</head>
<body>
<h3>ソースコード</h3>
<pre class="source200" id="program1"></pre>
<h3>結果</h3>
<pre class="source100" id="result1"></pre>
</body>
</html>

myjs.css
@charset "utf-8";
/* @charset is only recognised in the exact form @charset "..."; (double quotes). */
pre {
    border: 1px solid #000000;
    overflow: auto;   /* show scrollbars only when the content overflows */
    line-height: 100%;
}
/* Source listing box: capped at 20em (about 20 lines at line-height 100%). */
pre.source200 {
    max-height: 20em;
}
/* Result log box: capped at 10em (about 10 lines at line-height 100%). */
pre.source100 {
    max-height: 10em;
}
myjs.js
/**
 * Teaches a small feed-forward neural network the XOR function with batch
 * back-propagation plus momentum, then feeds it the test inputs and writes
 * the training log and the results into the page.
 *
 * Network: 2 inputs -> `hidden` sigmoid units -> 1 sigmoid output. Each
 * layer also has a bias weight driven by a constant 1.0 input.
 *
 * Called with no argument it uses the original hard-coded settings; any of
 * them can be overridden through `options` (e.g. for quick experiments or
 * deterministic runs with a seeded `random`).
 *
 * Log lines are PREPENDED to the output element, so the newest line is
 * shown first ("End", then the test results, then the training log).
 *
 * @param {Object} [options]
 * @param {number} [options.hidden=8] number of hidden units
 * @param {number} [options.maxEpochs=500000] maximum number of training epochs
 * @param {number} [options.printInterval=1000] log the error every N epochs
 * @param {number} [options.eta=2.5] learning rate
 * @param {number} [options.eps=0.0001] stop once the mean squared error is below this
 * @param {number} [options.alpha=0.92] momentum factor
 * @param {number} [options.beta=0.35] sigmoid gain
 * @param {number} [options.w0=0.9] initial weights are uniform in [-w0, w0)
 * @param {function(): number} [options.random=Math.random] uniform [0, 1) generator
 * @param {string} [options.resultId='result1'] id of the element receiving the log
 * @returns {{epochs: number, error: number, outputs: number[][]}} number of
 *     epochs trained, the last mean squared error (NaN if no epoch ran), and
 *     the network outputs for each test input, in test-data order
 */
function program1(options) {
    var opt = options || {};

    // Hyper-parameters (defaults are the original hard-coded values).
    var iINPUT = 2;
    var iHIDDEN = option('hidden', 8);
    var iOUTPUT = 1;
    var iPR = option('printInterval', 1000);
    var iMAX_T = option('maxEpochs', 500000);
    var dETA = option('eta', 2.5);
    var dEPS = option('eps', 0.0001);
    var dALPHA = option('alpha', 0.92);
    var dBETA = option('beta', 0.35);
    var dW0 = option('w0', 0.9);
    var random = option('random', Math.random);
    var resultId = option('resultId', 'result1');

    // Layer activations. xi[iINPUT] and v[iHIDDEN] hold the constant 1.0
    // that feeds the bias weight of the next layer.
    var xi = [];    // input layer,  length iINPUT + 1
    var v = [];     // hidden layer, length iHIDDEN + 1
    var o = [];     // output layer, length iOUTPUT
    var zeta = [];  // teacher signal of the current pattern, length iOUTPUT

    // Weights, this epoch's accumulated updates, and the previous epoch's
    // applied updates (momentum term).
    // w1[j][k]: input k -> hidden j; w2[j2][j]: hidden j -> output j2.
    // The last column of each matrix is the bias weight.
    var w1 = matrix(iHIDDEN, iINPUT + 1);
    var d_w1 = matrix(iHIDDEN, iINPUT + 1);
    var pre_dw1 = matrix(iHIDDEN, iINPUT + 1);
    var w2 = matrix(iOUTPUT, iHIDDEN + 1);
    var d_w2 = matrix(iOUTPUT, iHIDDEN + 1);
    var pre_dw2 = matrix(iOUTPUT, iHIDDEN + 1);

    var data;    // training inputs,  iPATTERNz x iINPUT
    var t_data;  // teacher signals,  iPATTERNz x iOUTPUT
    var d_data;  // test inputs,      iPATTERNo x iINPUT
    var iPATTERNz = 0;
    var iPATTERNo = 0;

    var result = { epochs: 0, error: NaN, outputs: [] };

    log_out('Start\n');
    load_data();
    back_propagation_main();
    tryTest();
    log_out('End\n');
    return result;

    // Returns options[name] when it was supplied, otherwise the default.
    function option(name, fallback) {
        return opt[name] === undefined ? fallback : opt[name];
    }

    // Builds a rows x cols array of arrays filled with 0.
    function matrix(rows, cols) {
        var m = [];
        for (var r = 0; r < rows; r++) {
            m[r] = [];
            for (var c = 0; c < cols; c++) {
                m[r][c] = 0;
            }
        }
        return m;
    }

    // Fills the training set (each XOR truth-table row appears twice) and
    // the inputs to try after training.
    function load_data() {
        // Inputs and their teacher signals (the expected answers).
        data = [[0, 0], [0, 1], [1, 0], [1, 1], [1, 0], [1, 1], [0, 1], [0, 0]];
        t_data = [[0], [1], [1], [0], [1], [0], [1], [0]];
        // Inputs evaluated once training is done.
        d_data = [[1, 0], [1, 1], [0, 1], [0, 0]];
        // Derive the counts so they can never disagree with the data.
        iPATTERNz = data.length;
        iPATTERNo = d_data.length;
    }

    // Batch training: accumulates the updates over every pattern, applies
    // them once per epoch, and stops after iMAX_T epochs or as soon as the
    // mean squared error drops below dEPS.
    function back_propagation_main() {
        var E = NaN;
        var Esum, t, p;
        w_init();
        for (t = 0; t < iMAX_T; t++) {
            dw_init();
            Esum = 0;
            for (p = 0; p < iPATTERNz; p++) {
                xi_set(p);
                forward();
                backward();
                Esum = Esum + calc_error();
            }
            w_modify();
            result.epochs = t + 1;

            E = Esum / (iOUTPUT * iPATTERNz);
            if (t % iPR === 0) {
                log_out(t + '/' + iMAX_T + ' ' + E + '\n');
            }
            if (E < dEPS) {
                break;
            }
        }
        result.error = E;
    }

    // Runs every test input through the network and logs
    // "p/in0(0),in1(1),_out(0)," per pattern.
    function tryTest() {
        // Bias input: set it here instead of relying on training having
        // left 1.0 in xi[iINPUT] (it is undefined when no epoch ran).
        xi[iINPUT] = 1.0;
        for (var p = 0; p < iPATTERNo; p++) {
            var s = p + '/';
            for (var k = 0; k < iINPUT; k++) {
                xi[k] = d_data[p][k];
                s = s + xi[k] + '(' + k + '),';
            }
            s = s + '_';
            forward();
            for (var j2 = 0; j2 < iOUTPUT; j2++) {
                s = s + o[j2] + '(' + j2 + '),';
            }
            log_out(s + '\n');
            result.outputs.push(o.slice(0, iOUTPUT));
        }
    }

    // Random initial weights (bias weights included); clears the updates.
    function w_init() {
        var j, k, j2;
        for (j = 0; j < iHIDDEN; j++) {
            for (k = 0; k <= iINPUT; k++) {
                w1[j][k] = arand();
                d_w1[j][k] = 0.0;
            }
        }
        for (j2 = 0; j2 < iOUTPUT; j2++) {
            for (j = 0; j <= iHIDDEN; j++) {
                w2[j2][j] = arand();
                d_w2[j2][j] = 0.0;
            }
        }
    }

    // Start of an epoch: remember the update applied last epoch (it is the
    // momentum term in w_modify) and clear the accumulators.
    function dw_init() {
        var j, k, j2;
        for (j = 0; j < iHIDDEN; j++) {
            for (k = 0; k <= iINPUT; k++) {
                pre_dw1[j][k] = d_w1[j][k];
                d_w1[j][k] = 0.0;
            }
        }
        for (j2 = 0; j2 < iOUTPUT; j2++) {
            for (j = 0; j <= iHIDDEN; j++) {
                pre_dw2[j2][j] = d_w2[j2][j];
                d_w2[j2][j] = 0.0;
            }
        }
    }

    // Loads training pattern p into the input layer and teacher signal.
    function xi_set(p) {
        var k, i;
        for (k = 0; k < iINPUT; k++) {
            xi[k] = data[p][k];
        }
        xi[iINPUT] = 1.0;
        for (i = 0; i < iOUTPUT; i++) {
            zeta[i] = t_data[p][i];
        }
    }

    // Forward pass: input -> hidden (v) -> output (o).
    function forward() {
        var j, k, j2, sum;
        for (j = 0; j < iHIDDEN; j++) {
            sum = 0;
            for (k = 0; k <= iINPUT; k++) {
                sum = sum + xi[k] * w1[j][k];
            }
            v[j] = sigmoid(sum);
        }
        v[iHIDDEN] = 1.0;
        for (j2 = 0; j2 < iOUTPUT; j2++) {
            sum = 0;
            for (j = 0; j <= iHIDDEN; j++) {
                sum = sum + v[j] * w2[j2][j];
            }
            o[j2] = sigmoid(sum);
        }
    }

    // Backward pass for the current pattern: computes the deltas (dBETA
    // comes from the derivative of the gained sigmoid) and adds this
    // pattern's contribution to d_w1 / d_w2.
    function backward() {
        var delta2 = [];
        var delta1 = [];
        var i, j, j2, k, sum;
        for (i = 0; i < iOUTPUT; i++) {
            delta2[i] = dBETA * o[i] * (1 - o[i]) * (zeta[i] - o[i]);
        }
        for (j = 0; j < iHIDDEN; j++) {
            sum = 0;
            for (j2 = 0; j2 < iOUTPUT; j2++) {
                sum = sum + w2[j2][j] * delta2[j2];
            }
            // Computed once, after the error from every output is summed.
            delta1[j] = dBETA * v[j] * (1 - v[j]) * sum;
        }
        for (j2 = 0; j2 < iOUTPUT; j2++) {
            for (j = 0; j <= iHIDDEN; j++) {
                d_w2[j2][j] = d_w2[j2][j] + delta2[j2] * v[j];
            }
        }
        for (j = 0; j < iHIDDEN; j++) {
            for (k = 0; k <= iINPUT; k++) {
                d_w1[j][k] = d_w1[j][k] + delta1[j] * xi[k];
            }
        }
    }

    // Squared error of the current pattern, summed over the outputs.
    function calc_error() {
        var E = 0.0;
        for (var i = 0; i < iOUTPUT; i++) {
            E = E + (zeta[i] - o[i]) * (zeta[i] - o[i]);
        }
        return E;
    }

    // Applies the epoch's updates with momentum:
    //   dw = alpha * eta * accumulated + alpha * dw_previous;  w += dw
    // (alpha also scales the gradient term in this formulation.)
    function w_modify() {
        var j, k, j2;
        for (j2 = 0; j2 < iOUTPUT; j2++) {
            for (j = 0; j <= iHIDDEN; j++) {
                d_w2[j2][j] = dALPHA * dETA * d_w2[j2][j] + dALPHA * pre_dw2[j2][j];
                w2[j2][j] = w2[j2][j] + d_w2[j2][j];
            }
        }
        for (j = 0; j < iHIDDEN; j++) {
            for (k = 0; k <= iINPUT; k++) {
                d_w1[j][k] = dALPHA * dETA * d_w1[j][k] + dALPHA * pre_dw1[j][k];
                w1[j][k] = w1[j][k] + d_w1[j][k];
            }
        }
    }

    // Prepends s to the output element, so the newest line is shown first.
    // textContent is used so that reading the text does not force a layout;
    // the <pre> element keeps the newlines.
    function log_out(s) {
        var el = document.getElementById(resultId);
        el.textContent = s + el.textContent;
    }

    // Uniform random number in [-dW0, dW0).
    function arand() {
        return random() * 2 * dW0 - dW0;
    }

    // Logistic sigmoid with gain dBETA.
    function sigmoid(u) {
        return 1.0 / (1.0 + Math.exp(-dBETA * u));
    }
}

※ソースを一部修正
function w_modify()
for(var j = 0 ; j < iOUTPUT; j++){
for(var j = 0 ; j < iHIDDEN; j++){