LoginSignup
9
9

More than 5 years have passed since last update.

ブロック崩し攻略人工知能の学習確認テスト

Last updated at Posted at 2015-04-09

はじめに

ブロック崩しを攻略する人工知能を作っています。今回は私たちの作ったニューラルネットワーク+Q-learningがちゃんと学習しているかどうかテストします。

テスト方法

前回までの学習方法は
スコアを上昇させるか,ボールを跳ね返せば褒められ,ボールを落とせば罰を受ける
でした。今回は
ボールとバーの距離を縮めれば褒められ,離せば罰を受ける
とします。ボールを追いかけるようにバーを動けば学習成功です。
今回は行動に対して報酬がすぐに得られるので,学習もしやすいはずです。

結果動画

※YOUTUBEに飛びます
block3.jpg
私の人工知能がブロック崩しをする動画

見事に動いていますね。

考察やら,まとめやら,課題やら,,,

今回の結果で私たちの人工知能が学習していることはわかりました。ただし,望む人工知能ではありません。私が望むのは目標を教えれば,手段を自分で獲得する人工知能です。今回の人工知能は手段を教わって,ただそれを実行しただけです。googleのDQNはゲーム画面を元にスコアが上昇すれば褒められる学習方法です。つまり,スコアの上昇という目標を与えられ,自ら手段を獲得しています。
前回書いたように,今やれることは,Q値の更新を過去にも伝えることです。そうすれば報酬が得られた後,報酬を得るまでの行動に価値が生まれ,今回のような学習しやすい状況になるかもしれません。

実装コード

前回紹介したボルツマン選択を使っています。


<!DOCTYPE html>
<html lang="ja">
<head>
<meta charset="UTF-8">
<title>ブロック崩し - ランダム</title>
<meta http-equiv="Content-Style-Type" content="text/css">
<meta http-equiv="Content-Script-Type" content="text/javascript">
<meta name="robots" content="noindex,nofollow">
<style type="text/css">

body{ -moz-user-select: none; -webkit-user-select: none; -ms-user-select: none; font-size: 20px;}
canvas{ border: none; position: absolute; left: 10px; top: 10px;}
#graph2{ cursor: pointer;}
#data{ position: relative; top: 410px;}
#myPercent{ width: 40px; font-size: inherit;}
#quickCalc{ font-size: inherit;}


</style>
</head>
<body>

<canvas id='graph'></canvas>
<canvas id='graph2'></canvas>
<p id='data'>
    <input type='checkbox' id='resetEverytime' checked>落ちたらリセット<br>
    <input type='text' value='100' id='myPercent'> 
    <input type='button' id='quickCalc' value='描画せずに計算'> 
    <span id='myResult'></span>
</p>

<script>
(function(){

var rand;
    //バーを動かす関数
    function moveBar(T){
            rand =Boltzmann(p.x,p.y,bar.x,p.vx,p.vy,T);
            var tempX = bar.x+(rand-1)*30; //左右に10ピクセルまたは動かない

            if (tempX<p.r) tempX = p.r; //左の壁に当たったら、それ以上左へは行かない
            else if (tempX>w-p.r) tempX = w-p.r; //右の壁に当たったらそれ以上右へは行かない
            bar.x = tempX;
    }

    function d(ID){ return document.getElementById(ID);}
    var timer;
    var paused = false, gameOver = false, gameEnded = true;
    var numberOfBalls = 0; //使用したボールの数
    var numberOfFrames = 0; //かかったフレーム数(÷33=時間(秒))
    var graph = d('graph'), graph2 = d('graph2');
    var g = graph.getContext('2d'), g2 = graph2.getContext('2d');
    var w = 600, h = 400;
    graph.width = graph2.width = w; graph.height = graph2.height = h;
    g.fillStyle = '#000';
    g.fillRect(0,0,w,h);
    g.save();
    g.font = '30px Arial';
    g.fillStyle = 'white';
    g.fillText('Click to Start',210,220);
    g.restore();

    // pはボール。順にx座標、y座標、半径、速度x成分、速度y成分,色
    var p = {x: 300, y: 20, r: 10, vx: 0, vy: 0, color: '#fff'};

    //barは棒。x座標、yは厚み、Lは長さの半分、edgeは角度が付き始める点の端からの距離
    var bar = {x: 300, y: 10, L: 50, edge: 45, color: '#ff0'};

    //bはブロック。b.xとb.yがともに配列で、xy座標を示す。wはwidthで幅、hはheightで高さ。全て同じ大きさの長方形
    var b = {x: [], y: [], w: 20, h: 10, color: '#88f'};

    var score = 0;

    d('graph2').addEventListener('click',function(){
        if (gameEnded){
            gameEnded = false;
            reset();
            drawBG();
            mainLoop();
        }
        else if (paused){
            paused = false;
            drawBG();
            mainLoop();
        }
        else{
            paused = true;
            clearTimeout(timer);
            g2.font = '30px Arial';
            g2.fillStyle = 'white';
            g2.fillText('paused',250,220);
        }
    },false);

    function reset(){
        bar.x = 300;
        p.x = bar.x; p.y = bar.y+p.r;
        var angle = Math.PI*(10)/180;//Math.PI*(Math.random()*60-30)/180; //-30~30
        p.vx = Math.sin(angle)*12; p.vy = Math.cos(angle)*12; //cosとsinの反転に注意
        b.x = []; b.y = [];
        for (var i=0; i<3; i++) for (var j=0; j<11; j++){
            b.x.push(w/2-10*b.w+j*b.w*2);
            b.y.push(340-i*b.h*2);
        }
        score = numberOfBalls = numberOfFrames = 0;
    }
    function mainLoop(){
        g.fillStyle = '#000';
        g.fillRect(0,0,w,h);
        g.setTransform(1,0,0,-1,0,h);

        //ボールの描画
        g.beginPath();
        g.arc(p.x,p.y,p.r,0,2*Math.PI);
        g.fillStyle = p.color;
        g.fill();
        //ボールの描画

        //バーの描画
        g.fillStyle = bar.color;
        g.fillRect(bar.x-bar.L+p.r+3,0,(bar.L-p.r-3)*2,bar.y); //±3は"遊び"

        //ボールを現在の速度をもとに移動させる
        p.x += p.vx;
        p.y += p.vy;

        //バーを動かす
        moveBar(1.0/Math.log(Time+1.001));

        //ここから衝突判定。まずはブロック
        if (p.y>=280){
            for (var i=0,I=b.x.length,hit=false; i<I; i++){
                if ((p.y-b.y[i])*p.vy<0 && Math.abs(p.x-b.x[i])<=b.w && Math.abs(p.y-b.y[i])<=b.h+p.r){
                    p.vy *= -1;
                    hit = true;
                    break;
                }
                else if ((p.x-b.x[i])*p.vx<0 &&Math.abs(p.y-b.y[i])<=b.h && Math.abs(p.x-b.x[i])<=b.w+p.r){
                    p.vx *= -1;
                    hit = true;
                    break;
                }
            }
            if (hit){
                score += 10;
                b.x.splice(i,1); b.y.splice(i,1);
                drawBG();
                if (b.x.length<=0){
                    gameEnded = true;
                    d('graph2').style.cursor = 'pointer';
                    g2.font = '40px Arial';
                    g2.fillStyle = 'white';
                    g2.fillText('SUCCESS !',200,200);
                    var timeLapse = Math.floor(100*numberOfFrames/33)/100;
                    g2.fillText(timeLapse+' sec',220,250);
                    return;
                }
            }
        }
        //↓衝突判定。ここから壁
        if (p.x>w-p.r){ p.x = w-p.r; p.vx *= -1;} //right
        if (p.x<p.r){   p.x = p.r; p.vx *= -1;} //left
        if (p.y>h-p.r){ p.y = h-p.r; p.vy *= -1;} //up
        if (p.y<p.r+bar.y){
            var p_bar = Math.abs(p.x-bar.x);
            if (!gameOver && p_bar<=bar.L){
                var X = (p.x>bar.x) ? 1 : -1; //衝突点のバーの法線ベクトル(X,1);
                if (p_bar<=bar.L-bar.edge) X = 0; //バーの中央
                else if (p_bar<=bar.L-bar.edge/3) X *= (p_bar-bar.L+bar.edge)/100; //0~0.3(バーの端は約73°)
                else X *= 0.3;
                var L = Math.sqrt(X*X+1); //法線ベクトルの長さ
                var vec = {x: X/L, y: 1/L}; //法線ベクトルの正規化
                var dot = vec.x*p.vx+vec.y*p.vy;
                p.y = p.r+bar.y;
                p.vx -= 2*dot*vec.x;
                p.vy -= 2*dot*vec.y;
                if (Math.abs(p.vx)/p.vy>1.5){ //角度が付き過ぎないよう調整
                    var v2 = p.vx*p.vx+p.vy*p.vy;
                    p.vy = Math.sqrt(v2/3.25);
                    p.vx = (p.vx<0) ? -p.vy*1.5 : p.vy*1.5;
                }
            }
            else if (!gameOver && p.y<p.r) gameOver = true;
            else if (gameOver && p.y<-p.r){
                gameOver = false;
                if (d('resetEverytime').checked) reset();
                else{
                    p.x = bar.x; p.y = bar.y+p.r;
                    var angle = Math.PI*(30)/180;//Math.PI*(Math.random()*60-30)/180; //-30~30
                    p.vx = Math.sin(angle)*12; p.vy = Math.cos(angle)*12; //cosとsinの反転に注意
                    numberOfBalls++;
                }
                drawBG();
            }
        }
        numberOfFrames++;
        timer = setTimeout(mainLoop,10); //30ミリ秒後にmainloop実行 (-> 1秒に約30回繰り返す)
    }
    //↓ブロックの描画。動きがないのに毎回描くのは無駄なので、別にして処理を節約。
    function drawBG(){
        g2.clearRect(0,0,w,h);
        g2.font = '20px Arial';
        g2.fillStyle = '#fff';
        g2.fillText('Balls: '+numberOfBalls,15,25);
        g2.fillText('Score: '+score,480,25);
        g2.save();
        g2.setTransform(1,0,0,-1,0,h);
        g2.fillStyle = '#ccf';
        for (var i=0,I=b.x.length; i<I; i++) g2.fillRect(b.x[i]-b.w,b.y[i]-b.h,2*b.w,2*b.h);
        g2.fillStyle = b.color;
        for (var i=0; i<I; i++) g2.fillRect(b.x[i]-b.w+1,b.y[i]-b.h+1,2*b.w-2,2*b.h-2);
        g2.restore();
    };

    //バイアス含めた(7,33,3)のneuralnet
    var inputN = new Array(5);
    //var inputN = new Array(3);
    inputN.push(-1);
    var numberI = inputN.length;

    var hiddenN = new Array(32);
    hiddenN.push(-1);
    var numberH = hiddenN.length;

    var inputW = new Array(numberI);
    for(var i=0;i<numberI;i++){
        inputW[i] = new Array(numberH-1);
        for(var j=0;j<numberH-1;j++){
            inputW[i][j] = 0;
        };
    };

    var outputN = new Array(3);
    var numberO = outputN.length;

    //初期の重みを0にすることで確率を均等にする。(sigmoid(0)=0.5よりoutputは全て0.5)
    var hiddenW = new Array(numberH);
    for(var i=0;i<numberH;i++){
        hiddenW[i] = new Array(numberO);
        for(var j=0;j<numberO;j++){
            hiddenW[i][j] = 0;
        };
    };

    const GAIN = 1.00;
    function sigmoid(x){
        return 1.00/(1.00 + Math.exp(-GAIN*x));
    };

    function Q(x,y,z,vx,vy,a){
        inputN = [x/600,y/400,z/600,vx/6,vx/11,-1];
        for(var i=0;i<numberH-1;i++){
            var suminput=0.0;
            for(var j=0;j<numberI;j++){
                suminput += inputW[j][i]*inputN[j];
            };
            hiddenN[i] = sigmoid(suminput);
        };

        for(var i=0;i<numberO;i++){
            var sumhidden=0.0;
            for(var j=0;j<numberH;j++){
                sumhidden += hiddenW[j][i]*hiddenN[j];
            };
            outputN[i] = sigmoid(sumhidden);
        };
        return outputN[a];
    };

    //ボルツマン選択
    function Boltzmann(x,y,z,vx,vy,T){
        var myQ = [Q(x,y,z,vx,vy,0),Q(x,y,z,vx,vy,1),Q(x,y,z,vx,vy,2)];
        var prob = new Array(3);
        for(var i=0;i<3;i++){
            prob[i] = Math.exp(myQ[i]/T)/(Math.exp(myQ[0]/T) + Math.exp(myQ[1]/T) + Math.exp(myQ[2]/T));
        };
        var select = Math.random();
        var sum = 0.0;
        for(var i=0;i<3;i++){
            sum +=prob[i]
            if (select<sum) {
                return i
            };
        };
    };

    console.log("start");
    var NewQ;
    var myPercent;
    var Time = 0;

    //↓ここから計算用
    d('quickCalc').addEventListener('click',function(){
        myPercent = document.getElementById('myPercent').value-0;
        if (isNaN(myPercent) || myPercent>100 || myPercent<0) myPercent = 100;
        for (var i=0;i<100;i++){
            reset(); //リセットしてから
            calculate(); //計算する
            console.log(Math.max.apply(null,hiddenW[0]));
        };

        d('myResult').innerHTML = 'スコア: '+score+', フレーム数: '+numberOfFrames+', 失敗した回数: '+numberOfBalls +', Time:' + Time;
        reset(); //念のためリセット
    },false);
    function calculate(){
        //更新式パラの初期値
        const ALPHA = 0.5;
        const GANNMA = 1.0;
        var reward;

        while (numberOfFrames<1000){


            var oldp = {x:p.x,y:p.y,vx:p.vx,vy:p.vy};
            var oldbar = bar.x;
            p.x += p.vx;
            p.y += p.vy;
            moveBar(1.0/Math.log(Time + 1.000010));             
            Time += 1;

            reward = 0;

            if (p.y>=280){
                for (var i=0,I=b.x.length,hit=false; i<I; i++){
                    if ((p.y-b.y[i])*p.vy<0 && Math.abs(p.x-b.x[i])<=b.w && Math.abs(p.y-b.y[i])<=b.h+p.r){
                        p.vy *= -1; hit = true; break;}
                    else if ((p.x-b.x[i])*p.vx<0 &&Math.abs(p.y-b.y[i])<=b.h && Math.abs(p.x-b.x[i])<=b.w+p.r){
                        p.vx *= -1; hit = true; break;}
                }
                if (hit){
                    score += 10;
                    reward = 2;
                    b.x.splice(i,1); b.y.splice(i,1);
                    if (b.x.length==0) break;
                }
            }
            if (p.x>w-p.r){ p.x = w-p.r; p.vx *= -1;} //right
            if (p.x<p.r){   p.x = p.r; p.vx *= -1;} //left
            if (p.y>h-p.r){ p.y = h-p.r; p.vy *= -1;} //up
            if (p.y<p.r+bar.y){
                var p_bar = Math.abs(p.x-bar.x);
                if (p_bar<=bar.L){
                    var X = (p.x>bar.x) ? 1 : -1; //衝突点のバーの法線ベクトル(X,1);
                    if (p_bar<=bar.L-bar.edge) X = 0; //バーの中央
                    else if (p_bar<=bar.L-bar.edge/3) X *= (p_bar-bar.L+bar.edge)/100; //0~0.3(バーの端は約73°)
                    else X *= 0.3;
                    var L = Math.sqrt(X*X+1); //法線ベクトルの長さ
                    var vec = {x: X/L, y: 1/L}; //法線ベクトルの正規化
                    var dot = vec.x*p.vx+vec.y*p.vy;
                    p.y = p.r+bar.y;
                    p.vx -= 2*dot*vec.x;
                    p.vy -= 2*dot*vec.y;
                    if (Math.abs(p.vx)/p.vy>1.5){ //角度が付き過ぎないよう調整
                        var v2 = p.vx*p.vx+p.vy*p.vy;
                        p.vy = Math.sqrt(v2/3.25);
                        p.vx = (p.vx<0) ? -p.vy*1.5 : p.vy*1.5;
                    }
                    reward = 10;
                }else if (p.y<p.r){
                    reward = -10;
                    break;  
                };
            };

            var nextQ = [Q(p.x,p.y,bar.x,p.vx,p.vy,0),Q(p.x,p.y,bar.x,p.vx,p.vy,1),Q(p.x,p.y,bar.x,p.vx,p.vy,2)];
            var oldQ = Q(oldp.x,oldp.y,oldbar,oldp.vx,oldp.vy,rand);
            if(reward==0 && Math.abs(oldp.x-oldbar)>Math.abs(p.x-bar.x)){
                reward = 1;
            }else if (reward ==0){
                reward = -1;
            };
            NewQ = oldQ + ALPHA*(reward + GANNMA*Math.max(nextQ[0],nextQ[1],nextQ[2]) - oldQ);
            if(NewQ>1){
                NewQ = 1;
            }else if(NewQ<0){
                NewQ = 0;
            };
            train(rand);
            numberOfFrames++;
        };
    };

    function train(oNum){
        const ETA = 0.3;

        var errorOut = 0;
        var errorHid = new Array(numberH);

        //出力層の誤差
        if(errorOut*(outputN[oNum] - NewQ)>0){
            errorOut = 2*(outputN[oNum] - NewQ)*outputN[oNum]*(1.0 - outputN[oNum])
        }else{
            errorOut = (outputN[oNum] - NewQ)*outputN[oNum]*(1.0 - outputN[oNum]);
        };

        for (var i=0;i<numberH;i++){
            //隠れ層の誤差
            errorHid[i] = hiddenW[i][oNum]*errorOut*hiddenN[i]*(1.0 - hiddenN[i]);

            //重みの更新
            hiddenW[i][oNum] -= ETA*errorOut*hiddenN[i];
        };
        //誤差逆伝播法
        for(var i=0;i<numberI;i++){
            for(var j=0;j<numberH-1;j++){
                inputW[i][j] -= ETA*errorHid[j]*inputN[i];
            };
        };
    };

})();
</script>
</body>
</html>

おまけ

遺伝的アルゴリズムでブロック崩しを攻略する方法も考えています。あっているかわかりませんが,ナップザック問題を解く遺伝的アルゴリズムを実装したのでまた書きます。
参考になる資料や,応援,共同研究希望などあればコメントください。

9
9
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
9
9