LoginSignup
3
2

More than 5 years have passed since last update.

JavaScript でニューラルネット(XOR回路を作ってみる)

Last updated at Posted at 2018-01-07

ちょっと前に書いたコードのメモ:

XOR(排他的論理和)を、ニューラルネットに学習させる例

<XOR>
True と True で False  (1 ^ 1 -> 0)
True と False で True   (1 ^ 0 -> 1)
False と True で True   (0 ^ 1 -> 1)
False と False で False  (0 ^ 0 -> 0)

結果
3/0(0),0(1),_0.009335834287899665(0),
2/0(0),1(1),_0.9981023530577846(0),
1/1(0),1(1),_0.010547904797307581(0),
0/1(0),0(1),_0.9859294247263731(0),

<結果の見方> ・・・ 各入力に対する出力が XOR の値(0 または 1)に近くなっており、XOR回路が学習できている(log_out は新しい行を先頭に追加するため、上の結果は実行順とは逆順に並んでいる)
0,0 →  0.00933
0,1 →  0.99810
1,1 →  0.01054
1,0 →  0.98592

myjs.html
<!doctype html>
<html>
<head>
<!-- Page encoding. Linked stylesheets and scripts without their own declaration fall back to it,
     so the obsolete charset attributes on <link>/<script> are not needed. -->
<meta charset="utf-8">
<link rel="stylesheet" href="myjs.css">
<script src="./myjs.js"></script>
<script>
// Once the page has loaded, show program1's own source text in #program1,
// then run it; program1 writes its progress and results into #result1.
window.addEventListener('load', function () {
    document.getElementById('program1').textContent = String(program1);
    program1();
});
</script>
<title>ニューラルネットワークをJavaScriptで</title>
</head>
<body>
<h3>ソースコード</h3>
<pre class="source200" id="program1"></pre>
<h3>結果</h3>
<pre class="source100" id="result1"></pre>
</body>
</html>

myjs.css
@charset "utf-8";
/* Bordered, scrollable boxes for the <pre> elements of the demo page. */
pre {
    border: 1px solid #000000;
    overflow: scroll;
    line-height: 100%;
}
/* Source listing box (#program1): at most 20em tall. */
pre.source200 {
    max-height: 20em;
}
/* Result log box (#result1): at most 10em tall. */
pre.source100 {
    max-height: 10em;
}
myjs.js
/**
 * Trains a small feed-forward network (2 inputs, 8 sigmoid hidden units,
 * 1 sigmoid output; every layer has a bias weight) on the XOR truth table
 * using full-batch back-propagation with a momentum term, then feeds four
 * probe inputs through the trained network.
 *
 * Every output line goes through `log`. The default writer prepends to the
 * #result1 element, so the newest line is shown first (the probe results
 * therefore appear in reverse order on the page).
 *
 * @param {Object} [options]
 * @param {number} [options.maxEpochs=500000] Upper bound on training epochs;
 *     training also stops early once the mean squared error drops below dEPS.
 * @param {function(): number} [options.random=Math.random] Uniform [0, 1)
 *     source for the initial weights; pass a seeded generator for
 *     reproducible runs.
 * @param {function(string): void} [options.log] Output sink; defaults to
 *     prepending to the #result1 element.
 * @returns {number[][]} Network output vector for each probe input, in probe order.
 */
function program1(options = {}) {
    const {
        maxEpochs = 500000,
        random = Math.random,
        log = log_out,
    } = options;

    // Network shape and hyper-parameters.
    const iINPUT = 2;
    const iHIDDEN = 8;
    const iOUTPUT = 1;
    const iPR = 1000;      // log the training error every iPR epochs
    const dETA = 2.5;      // learning rate
    const dEPS = 0.0001;   // stop once the mean squared error falls below this
    const dALPHA = 0.92;   // momentum coefficient (see w_modify)
    const dBETA = 0.35;    // sigmoid gain: sigmoid(u) = 1 / (1 + exp(-dBETA * u))
    const dW0 = 0.9;       // initial weights are drawn uniformly from [-dW0, dW0)

    // Training inputs; row p pairs with row p of t_data.
    const data = [[0, 0], [0, 1], [1, 0], [1, 1], [1, 0], [1, 1], [0, 1], [0, 0]];
    // Teacher signal (expected XOR output) for each training row.
    const t_data = [[0], [1], [1], [0], [1], [0], [1], [0]];
    // Probe inputs run through the network after training.
    const d_data = [[1, 0], [1, 1], [0, 1], [0, 0]];

    // Activation buffers. The extra last slot of xi and v holds the constant
    // bias input 1.0 that multiplies the last column of w1 / w2.
    const xi = new Array(iINPUT + 1).fill(0);
    const v = new Array(iHIDDEN + 1).fill(0);
    const o = new Array(iOUTPUT).fill(0);
    const zeta = new Array(iOUTPUT).fill(0); // teacher signal of the current pattern

    // Weights, accumulated steps of the current epoch, and the previous
    // epoch's steps (for momentum). Each row has an extra bias column.
    const w1 = zeros(iHIDDEN, iINPUT + 1);
    const w2 = zeros(iOUTPUT, iHIDDEN + 1);
    const d_w1 = zeros(iHIDDEN, iINPUT + 1);
    const d_w2 = zeros(iOUTPUT, iHIDDEN + 1);
    const pre_dw1 = zeros(iHIDDEN, iINPUT + 1);
    const pre_dw2 = zeros(iOUTPUT, iHIDDEN + 1);

    log('Start\n');
    back_propagation_main();
    const results = tryTest();
    log('End\n');
    return results;

    /** Returns a rows x cols matrix filled with 0. */
    function zeros(rows, cols) {
        return Array.from({ length: rows }, () => new Array(cols).fill(0));
    }

    /** Initializes the weights, then trains until E < dEPS or maxEpochs epochs have run. */
    function back_propagation_main() {
        w_init();
        const nPatterns = data.length;
        for (let t = 0; t < maxEpochs; t++) {
            dw_init();
            let Esum = 0;
            // Full batch: accumulate the steps of every pattern, then update once.
            for (let p = 0; p < nPatterns; p++) {
                xi_set(p);
                forward();
                backward();
                Esum += calc_error();
            }
            w_modify();

            const E = Esum / (iOUTPUT * nPatterns);
            if (t % iPR === 0) {
                log(`${t}/${maxEpochs} ${E}\n`);
            }
            if (E < dEPS) {
                break;
            }
        }
    }

    /** Runs each probe input through the network, logging and collecting its outputs. */
    function tryTest() {
        const outputs = [];
        d_data.forEach((input, p) => {
            let s = `${p}/`;
            for (let k = 0; k < iINPUT; k++) {
                xi[k] = input[k];
                s += `${xi[k]}(${k}),`;
            }
            // Set the bias input explicitly instead of relying on xi_set()
            // having left it at 1.0 (it never runs when no epoch is trained).
            xi[iINPUT] = 1.0;
            s += '_';
            forward();
            for (let j2 = 0; j2 < iOUTPUT; j2++) {
                s += `${o[j2]}(${j2}),`;
            }
            log(`${s}\n`);
            outputs.push(o.slice());
        });
        return outputs;
    }

    /** Draws random initial weights: hidden layer first, then output layer. */
    function w_init() {
        for (const row of w1) {
            for (let k = 0; k < row.length; k++) {
                row[k] = arand();
            }
        }
        for (const row of w2) {
            for (let j = 0; j < row.length; j++) {
                row[j] = arand();
            }
        }
    }

    /** Saves the last epoch's steps as the momentum term and clears the accumulators. */
    function dw_init() {
        shift_steps(d_w1, pre_dw1);
        shift_steps(d_w2, pre_dw2);
    }

    /** Copies each entry of `d` into `pre`, then zeroes `d`. */
    function shift_steps(d, pre) {
        for (let r = 0; r < d.length; r++) {
            for (let c = 0; c < d[r].length; c++) {
                pre[r][c] = d[r][c];
                d[r][c] = 0.0;
            }
        }
    }

    /** Loads training pattern p into the input (plus bias) and teacher buffers. */
    function xi_set(p) {
        for (let k = 0; k < iINPUT; k++) {
            xi[k] = data[p][k];
        }
        xi[iINPUT] = 1.0;
        for (let i = 0; i < iOUTPUT; i++) {
            zeta[i] = t_data[p][i];
        }
    }

    /** Propagates xi through both layers, filling v (hidden) and o (output). */
    function forward() {
        for (let j = 0; j < iHIDDEN; j++) {
            let sum = 0;
            for (let k = 0; k <= iINPUT; k++) {
                sum += xi[k] * w1[j][k];
            }
            v[j] = sigmoid(sum);
        }
        v[iHIDDEN] = 1.0; // bias input for the output layer
        for (let j2 = 0; j2 < iOUTPUT; j2++) {
            let sum = 0;
            for (let j = 0; j <= iHIDDEN; j++) {
                sum += v[j] * w2[j2][j];
            }
            o[j2] = sigmoid(sum);
        }
    }

    /** Back-propagates the current pattern's error, adding its steps to d_w1 / d_w2. */
    function backward() {
        // The derivative of sigmoid(dBETA * u) is dBETA * s * (1 - s).
        const delta2 = new Array(iOUTPUT);
        for (let i = 0; i < iOUTPUT; i++) {
            delta2[i] = dBETA * o[i] * (1 - o[i]) * (zeta[i] - o[i]);
        }
        const delta1 = new Array(iHIDDEN);
        for (let j = 0; j < iHIDDEN; j++) {
            let sum = 0;
            for (let j2 = 0; j2 < iOUTPUT; j2++) {
                sum += w2[j2][j] * delta2[j2];
            }
            // Apply the derivative once the sum over all outputs is complete.
            delta1[j] = dBETA * v[j] * (1 - v[j]) * sum;
        }
        for (let j2 = 0; j2 < iOUTPUT; j2++) {
            for (let j = 0; j <= iHIDDEN; j++) {
                d_w2[j2][j] += delta2[j2] * v[j];
            }
        }
        for (let j = 0; j < iHIDDEN; j++) {
            for (let k = 0; k <= iINPUT; k++) {
                d_w1[j][k] += delta1[j] * xi[k];
            }
        }
    }

    /** Sum of squared differences between the teacher signal and the current output. */
    function calc_error() {
        let E = 0.0;
        for (let i = 0; i < iOUTPUT; i++) {
            E += (zeta[i] - o[i]) * (zeta[i] - o[i]);
        }
        return E;
    }

    /**
     * Applies the accumulated steps to the weights. The step actually used is
     *   step = dALPHA * dETA * accumulated + dALPHA * previousStep
     * so dALPHA scales the gradient term as well as the momentum term (unlike
     * the textbook dETA * gradient + dALPHA * previousStep). Kept as-is because
     * the hyper-parameters above were tuned against this form.
     */
    function w_modify() {
        apply_steps(w2, d_w2, pre_dw2);
        apply_steps(w1, d_w1, pre_dw1);
    }

    /** Turns the accumulated steps `d` into momentum steps and adds them to `w`. */
    function apply_steps(w, d, pre) {
        for (let r = 0; r < w.length; r++) {
            for (let c = 0; c < w[r].length; c++) {
                d[r][c] = dALPHA * dETA * d[r][c] + dALPHA * pre[r][c];
                w[r][c] += d[r][c];
            }
        }
    }

    /** Default output sink: prepends `s` to the #result1 element. */
    function log_out(s) {
        const el = document.getElementById('result1');
        el.textContent = s + el.textContent;
    }

    /** Uniform random number in [-dW0, dW0). */
    function arand() {
        return random() * 2 * dW0 - dW0;
    }

    /** Logistic sigmoid with gain dBETA. */
    function sigmoid(u) {
        return 1.0 / (1.0 + Math.exp(-dBETA * u));
    }
}

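※ソースを一部修正
修正前の w_modify() では、ループの先頭が以下のようになっていた。外側と内側のループ変数がどちらも j で衝突しており、内側の条件 `j < iHIDDEN` ではバイアスの重み(添字 iHIDDEN)が更新されなかった。現在のソースでは外側を j2、内側を `j <= iHIDDEN` に修正している。
function w_modify()
for(var j = 0 ; j < iOUTPUT; j++){
for(var j = 0 ; j < iHIDDEN; j++){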

3
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
2