LoginSignup
12
12

More than 5 years have passed since last update.

ブロック崩しを人工知能で攻略する(未完成)

Last updated at Posted at 2015-04-03

はじめに

googleのDQNに魅力を感じたので,人工知能を作成しています。
DQNがブロック崩しを攻略する動画
具体的にやりたいことは動画と同じように,ブロック崩しを攻略する人工知能の実装です。
※プログラミングも人工知能も本屋+ネットで勉強した程度なので,間違っているかもしれません。

進行状況

何万回も学習した結果がこれです↓
私の人工知能がブロック崩しを攻略?する動画
学習しているようですが,この後も学習を続けるとボールを追わなくなったりします。

実装コード

googleはdeep learningを使っていますが,ほとんど理解できていないので,代用品として単純なニューラルネットワークで実装しています。


<!DOCTYPE html>
<html lang="ja">
<head>
<meta charset="UTF-8">
<title>ブロック崩し - ランダム</title>
<meta http-equiv="Content-Style-Type" content="text/css">
<meta http-equiv="Content-Script-Type" content="text/javascript">
<meta name="robots" content="noindex,nofollow">
<style type="text/css">

body{ -moz-user-select: none; -webkit-user-select: none; -ms-user-select: none; font-size: 20px;}
canvas{ border: none; position: absolute; left: 10px; top: 10px;}
#graph2{ cursor: pointer;}
#data{ position: relative; top: 410px;}
#myPercent{ width: 40px; font-size: inherit;}
#quickCalc{ font-size: inherit;}


</style>
</head>
<body>

<canvas id='graph'></canvas>
<canvas id='graph2'></canvas>
<p id='data'>
    <input type='checkbox' id='resetEverytime' checked>落ちたらリセット<br>
    <input type='text' value='100' id='myPercent'> 
    <input type='button' id='quickCalc' value='描画せずに計算'> 
    <span id='myResult'></span>
</p>

<script>
(function(){

var rand;
    //バーを動かす関数
    function moveBar(per){
        //p.x = ボールのx座標, p.y = ボールのy座標, p.r = ボールの半径
        //p.vx = ボールの速度のx軸成分, p.vy = ボールの速度のy軸成分
        //b.x = 残っているブロックのx座標の配列, b.y = 残っているブロックのy座標の配列

        if(Math.floor(Math.random()*100)<per){
            rand = Math.floor(Math.random()*3); //0 or 1 or 2
        }else{
            rand = argMaxQ(p.x,p.y,bar.x);
        };

            var tempX = bar.x+(rand-1)*30; //左右に30ピクセルまたは動かない

            if (tempX<p.r) tempX = p.r; //左の壁に当たったら、それ以上左へは行かない
            else if (tempX>w-p.r) tempX = w-p.r; //右の壁に当たったらそれ以上右へは行かない
            bar.x = tempX;
    }


    function d(ID){ return document.getElementById(ID);}
    var timer;
    var paused = false, gameOver = false, gameEnded = true;
    var numberOfBalls = 0; //使用したボールの数
    var numberOfFrames = 0; //かかったフレーム数(÷33=時間(秒))
    var graph = d('graph'), graph2 = d('graph2');
    var g = graph.getContext('2d'), g2 = graph2.getContext('2d');
    var w = 600, h = 400;
    graph.width = graph2.width = w; graph.height = graph2.height = h;
    g.fillStyle = '#000';
    g.fillRect(0,0,w,h);
    g.save();
    g.font = '30px Arial';
    g.fillStyle = 'white';
    g.fillText('Click to Start',210,220);
    g.restore();

    // pはボール。順にx座標、y座標、半径、速度x成分、速度y成分,色
    var p = {x: 300, y: 20, r: 10, vx: 0, vy: 0, color: '#fff'};

    //barは棒。x座標、yは厚み、Lは長さの半分、edgeは角度が付き始める点の端からの距離
    var bar = {x: 300, y: 10, L: 50, edge: 45, color: '#ff0'};

    //bはブロック。b.xとb.yがともに配列で、xy座標を示す。wはwidthで幅、hはheightで高さ。全て同じ大きさの長方形
    var b = {x: [], y: [], w: 20, h: 10, color: '#88f'};

    var score = 0;

    d('graph2').addEventListener('click',function(){
        if (gameEnded){
            gameEnded = false;
            reset();
            drawBG();
            mainLoop();
        }
        else if (paused){
            paused = false;
            drawBG();
            mainLoop();
        }
        else{
            paused = true;
            clearTimeout(timer);
            g2.font = '30px Arial';
            g2.fillStyle = 'white';
            g2.fillText('paused',250,220);
        }
    },false);

    function reset(){
        bar.x = 300;
        p.x = bar.x; p.y = bar.y+p.r;
        var angle = Math.PI*(30)/180;//Math.PI*(Math.random()*60-30)/180; //-30~30
        p.vx = Math.sin(angle)*12; p.vy = Math.cos(angle)*12; //cosとsinの反転に注意
        b.x = []; b.y = [];
        for (var i=0; i<3; i++) for (var j=0; j<11; j++){
            b.x.push(w/2-10*b.w+j*b.w*2);
            b.y.push(340-i*b.h*2);
        }
        score = numberOfBalls = numberOfFrames = 0;
    }
    function mainLoop(){
        g.fillStyle = '#000';
        g.fillRect(0,0,w,h);
        g.setTransform(1,0,0,-1,0,h);

        //ボールの描画
        g.beginPath();
        g.arc(p.x,p.y,p.r,0,2*Math.PI);
        g.fillStyle = p.color;
        g.fill();
        //ボールの描画

        //バーの描画
        g.fillStyle = bar.color;
        g.fillRect(bar.x-bar.L+p.r+3,0,(bar.L-p.r-3)*2,bar.y); //±3は"遊び"

        //ボールを現在の速度をもとに移動させる
        p.x += p.vx;
        p.y += p.vy;

        //バーを動かす
        moveBar(0);

        //ここから衝突判定。まずはブロック
        if (p.y>=280){
            for (var i=0,I=b.x.length,hit=false; i<I; i++){
                if ((p.y-b.y[i])*p.vy<0 && Math.abs(p.x-b.x[i])<=b.w && Math.abs(p.y-b.y[i])<=b.h+p.r){
                    p.vy *= -1;
                    hit = true;
                    break;
                }
                else if ((p.x-b.x[i])*p.vx<0 &&Math.abs(p.y-b.y[i])<=b.h && Math.abs(p.x-b.x[i])<=b.w+p.r){
                    p.vx *= -1;
                    hit = true;
                    break;
                }
            }
            if (hit){
                score += 10;
                b.x.splice(i,1); b.y.splice(i,1);
                drawBG();
                if (b.x.length<=0){
                    gameEnded = true;
                    d('graph2').style.cursor = 'pointer';
                    g2.font = '40px Arial';
                    g2.fillStyle = 'white';
                    g2.fillText('SUCCESS !',200,200);
                    var timeLapse = Math.floor(100*numberOfFrames/33)/100;
                    g2.fillText(timeLapse+' sec',220,250);
                    return;
                }
            }
        }
        //↓衝突判定。ここから壁
        if (p.x>w-p.r){ p.x = w-p.r; p.vx *= -1;} //right
        if (p.x<p.r){   p.x = p.r; p.vx *= -1;} //left
        if (p.y>h-p.r){ p.y = h-p.r; p.vy *= -1;} //up
        if (p.y<p.r+bar.y){
            var p_bar = Math.abs(p.x-bar.x);
            if (!gameOver && p_bar<=bar.L){
                var X = (p.x>bar.x) ? 1 : -1; //衝突点のバーの法線ベクトル(X,1);
                if (p_bar<=bar.L-bar.edge) X = 0; //バーの中央
                else if (p_bar<=bar.L-bar.edge/3) X *= (p_bar-bar.L+bar.edge)/100; //0~0.3(バーの端は約73°)
                else X *= 0.3;
                var L = Math.sqrt(X*X+1); //法線ベクトルの長さ
                var vec = {x: X/L, y: 1/L}; //法線ベクトルの正規化
                var dot = vec.x*p.vx+vec.y*p.vy;
                p.y = p.r+bar.y;
                p.vx -= 2*dot*vec.x;
                p.vy -= 2*dot*vec.y;
                if (Math.abs(p.vx)/p.vy>1.5){ //角度が付き過ぎないよう調整
                    var v2 = p.vx*p.vx+p.vy*p.vy;
                    p.vy = Math.sqrt(v2/3.25);
                    p.vx = (p.vx<0) ? -p.vy*1.5 : p.vy*1.5;
                }
            }
            else if (!gameOver && p.y<p.r) gameOver = true;
            else if (gameOver && p.y<-p.r){
                gameOver = false;
                if (d('resetEverytime').checked) reset();
                else{
                    p.x = bar.x; p.y = bar.y+p.r;
                    var angle = Math.PI*(30)/180;//Math.PI*(Math.random()*60-30)/180; //-30~30
                    p.vx = Math.sin(angle)*12; p.vy = Math.cos(angle)*12; //cosとsinの反転に注意
                    numberOfBalls++;
                }
                drawBG();
            }
        }
        numberOfFrames++;
        timer = setTimeout(mainLoop,30); //30ミリ秒後にmainloop実行 (-> 1秒に約30回繰り返す)
    }
    //↓ブロックの描画。動きがないのに毎回描くのは無駄なので、別にして処理を節約。
    function drawBG(){
        g2.clearRect(0,0,w,h);
        g2.font = '20px Arial';
        g2.fillStyle = '#fff';
        g2.fillText('Balls: '+numberOfBalls,15,25);
        g2.fillText('Score: '+score,480,25);
        g2.save();
        g2.setTransform(1,0,0,-1,0,h);
        g2.fillStyle = '#ccf';
        for (var i=0,I=b.x.length; i<I; i++) g2.fillRect(b.x[i]-b.w,b.y[i]-b.h,2*b.w,2*b.h);
        g2.fillStyle = b.color;
        for (var i=0; i<I; i++) g2.fillRect(b.x[i]-b.w+1,b.y[i]-b.h+1,2*b.w-2,2*b.h-2);
        g2.restore();
    };

    //バイアス含めた(4,33,3)のneuralnet
    var inputN = new Array(3);
    inputN.push(-1);
    var numberI = inputN.length;

    var hiddenN = new Array(32);
    hiddenN.push(-1);
    var numberH = hiddenN.length;

    var inputW = new Array(numberI);
    for(var i=0;i<numberI;i++){
        inputW[i] = new Array(numberH-1);
        for(var j=0;j<numberH-1;j++){
            inputW[i][j] = Math.random() - 0.50;
        };
    };

    var outputN = new Array(3);
    var numberO = outputN.length;

    var hiddenW = new Array(numberH);
    for(var i=0;i<numberH;i++){
        hiddenW[i] = new Array(numberO);
        for(var j=0;j<numberO;j++){
            hiddenW[i][j] = Math.random() - 0.50;
        };
    };

    const GAIN = 1.00;
    function sigmoid(x){
        return 1.00/(1.00 + Math.exp(-GAIN*x));
    };

    function Q(x,y,z,a){
        inputN = [x/600,y/400,z/600,-1];

        for(var i=0;i<numberH-1;i++){
                        var suminput=0.0;//←訂正箇所
            for(var j=0;j<numberI;j++){
                suminput += inputW[j][i]*inputN[j];
            };
            hiddenN[i] = sigmoid(suminput);
        };

        for(var i=0;i<numberO;i++){
                        var sumhidden=0.0;//←訂正箇所
            for(var j=0;j<numberH;j++){
                sumhidden += hiddenW[j][i]*hiddenN[j];
            };
            outputN[i] = sigmoid(sumhidden);
        };
        return outputN[a];
    }
    function argMaxQ(x,y,z){
        var myQ = [Q(x,y,z,0),Q(x,y,z,1),Q(x,y,z,2)];
        if(myQ[1] > myQ[0]){
            if(myQ[2]>myQ[1]){
                return 2;               
            }else{
                return 1;
            };
        }else if(myQ[0] > myQ[1]){
            if(myQ[2]>myQ[0]){
                return 2;
            }else{
                return 0;
            };
        }else{
            if(myQ[2]>myQ[0]){
                return 2;
            }else{
                return Math.floor(Math.random()*3);
            };
        };
    };
    console.log("start");
    var NewQ;
    var myPercent;

    //↓ここから計算用
    d('quickCalc').addEventListener('click',function(){
        highscore = 0;
        myPercent = document.getElementById('myPercent').value-0;
        if (isNaN(myPercent) || myPercent>100 || myPercent<0) myPercent = 100;
        for (var i=0;i<5000;i++){
            reset(); //リセットしてから
            calculate(); //計算する
        };

        d('myResult').innerHTML = 'スコア: '+score+', フレーム数: '+numberOfFrames+', 失敗した回数: '+numberOfBalls;
        reset(); //念のためリセット
    },false);
    function calculate(){
        //更新式パラの初期値
        const ALPHA = 0.1;
        const GANNMA = 0.9;
        var reward;

        while (numberOfFrames<1000){


            var oldp = {x:p.x,y:p.y};
            var oldbar = bar.x;
            p.x += p.vx;
            p.y += p.vy;
            moveBar(myPercent);
            reward = 0;

            if (p.y>=280){
                for (var i=0,I=b.x.length,hit=false; i<I; i++){
                    if ((p.y-b.y[i])*p.vy<0 && Math.abs(p.x-b.x[i])<=b.w && Math.abs(p.y-b.y[i])<=b.h+p.r){
                        p.vy *= -1; hit = true; break;}
                    else if ((p.x-b.x[i])*p.vx<0 &&Math.abs(p.y-b.y[i])<=b.h && Math.abs(p.x-b.x[i])<=b.w+p.r){
                        p.vx *= -1; hit = true; break;}
                }
                if (hit){
                    score += 10;
                    if(score>40){
                        reward = 1;
                    }
                    b.x.splice(i,1); b.y.splice(i,1);
                    if (b.x.length==0) break;
                }
            }
            if (p.x>w-p.r){ p.x = w-p.r; p.vx *= -1;} //right
            if (p.x<p.r){   p.x = p.r; p.vx *= -1;} //left
            if (p.y>h-p.r){ p.y = h-p.r; p.vy *= -1;} //up
            if (p.y<p.r+bar.y){
                var p_bar = Math.abs(p.x-bar.x);
                if (p_bar<=bar.L){
                    var X = (p.x>bar.x) ? 1 : -1; //衝突点のバーの法線ベクトル(X,1);
                    if (p_bar<=bar.L-bar.edge) X = 0; //バーの中央
                    else if (p_bar<=bar.L-bar.edge/3) X *= (p_bar-bar.L+bar.edge)/100; //0~0.3(バーの端は約73°)
                    else X *= 0.3;
                    var L = Math.sqrt(X*X+1); //法線ベクトルの長さ
                    var vec = {x: X/L, y: 1/L}; //法線ベクトルの正規化
                    var dot = vec.x*p.vx+vec.y*p.vy;
                    p.y = p.r+bar.y;
                    p.vx -= 2*dot*vec.x;
                    p.vy -= 2*dot*vec.y;
                    if (Math.abs(p.vx)/p.vy>1.5){ //角度が付き過ぎないよう調整
                        var v2 = p.vx*p.vx+p.vy*p.vy;
                        p.vy = Math.sqrt(v2/3.25);
                        p.vx = (p.vx<0) ? -p.vy*1.5 : p.vy*1.5;
                    }
                    reward = 1;
                }else if (p.y<p.r){
                    reward = -1;
                    break;  
                };
            };
            var myQ = [Q(p.x,p.y,bar.x,0),Q(p.x,p.y,bar.x,1),Q(p.x,p.y,bar.x,2)];
            var oldQ = Q(oldp.x,oldp.y,oldbar,rand);
            NewQ = oldQ + ALPHA*(reward + GANNMA*Math.max(myQ[0],myQ[1],myQ[2]) - oldQ);
            if(NewQ>1){
                NewQ = 1;
            }else if(NewQ<0){
                NewQ = 0;
            };
            train(rand);
            numberOfFrames++;
        };
    };

    function train(oNum){
        const ETA = 0.4;

        var errorOut;
        var errorHid = new Array(numberH);

        //出力層の誤差
        errorOut = (outputN[oNum] - NewQ)*outputN[oNum]*(1.0 - outputN[oNum]);

        for (var i=0;i<numberH;i++){
            //隠れ層の誤差
            errorHid[i] = hiddenW[i][oNum]*errorOut*hiddenN[i]*(1.0 - hiddenN[i]);

            //重みの更新
            hiddenW[i][oNum] -= ETA*errorOut*hiddenN[i];
        };
        //誤差逆伝播法
        for(var i=0;i<numberI;i++){
            for(var j=0;j<numberH-1;j++){
                inputW[i][j] -= ETA*errorHid[j]*inputN[i];
            };
        };
    };

})();
</script>
</body>
</html>

コードの流れ

ブロック崩しの部分は友達が作ってくれたので,その部分のコードの意味はあまりわかりません。

バーを動かす関数

    function moveBar(per){
        if(Math.floor(Math.random()*100)<per){
//ランダムに動かす
            rand = Math.floor(Math.random()*3); //0 or 1 or 2
        }else{
//Q値が大きい方に動かす
            rand = argMaxQ(p.x,p.y,bar.x);
        };

            var tempX = bar.x+(rand-1)*30; //左右に30ピクセルまたは動かない

            if (tempX<p.r) tempX = p.r; //左の壁に当たったら、それ以上左へは行かない
            else if (tempX>w-p.r) tempX = w-p.r; //右の壁に当たったらそれ以上右へは行かない
            bar.x = tempX;
    }

以下人工知能部分について説明します。

ニューラルネットワークをつかってQ値を出す。

    //バイアス含めた(4,33,3)のneuralnet
    var inputN = new Array(3);
    inputN.push(-1);
    var numberI = inputN.length;

    var hiddenN = new Array(32);
    hiddenN.push(-1);
    var numberH = hiddenN.length;

    var inputW = new Array(numberI);
    for(var i=0;i<numberI;i++){
        inputW[i] = new Array(numberH-1);
        for(var j=0;j<numberH-1;j++){
            inputW[i][j] = Math.random() - 0.50;
        };
    };

    var outputN = new Array(3);
    var numberO = outputN.length;

    var hiddenW = new Array(numberH);
    for(var i=0;i<numberH;i++){
        hiddenW[i] = new Array(numberO);
        for(var j=0;j<numberO;j++){
            hiddenW[i][j] = Math.random() - 0.50;
        };
    };
    //シグモイド関数
    const GAIN = 1.00;
    function sigmoid(x){
        return 1.00/(1.00 + Math.exp(-GAIN*x));
    };
    //Q値 Q(ボールのx座標,ボールのy座標,バーのx座標,行動)
    function Q(x,y,z,a){
        inputN = [x/600,y/400,z/600,-1];

        for(var i=0;i<numberH-1;i++){
             var suminput=0.0;
            for(var j=0;j<numberI;j++){
                suminput += inputW[j][i]*inputN[j];
            };
            hiddenN[i] = sigmoid(suminput);
        };

        for(var i=0;i<numberO;i++){
                        var sumhidden=0.0;
            for(var j=0;j<numberH;j++){
                sumhidden += hiddenW[j][i]*hiddenN[j];
            };
            outputN[i] = sigmoid(sumhidden);
        };
        return outputN[a];
    }
//Q(x,y,z,0):左に動く価値,Q(x,y,z,1):止まる価値,Q(x,y,z,2):右に動く価値

そもそもこのやり方であっているかもわかりません。

更新式を使ってQ値の更新を行う


const ALPHA = 0.1;
const GANNMA = 0.9;
//Q(ボールのx座標,ボールのy座標,バーのx座標,行動):例えば行動=0なら左
var myQ = [Q(p.x,p.y,bar.x,0),Q(p.x,p.y,bar.x,1),Q(p.x,p.y,bar.x,2)];
var oldQ = Q(oldp.x,oldp.y,oldbar,rand);

//Q値の更新
NewQ = oldQ + ALPHA*(reward + GANNMA*Math.max(myQ[0],myQ[1],myQ[2]) - oldQ);
if(NewQ>1){
    NewQ = 1;
}else if(NewQ<0){
    NewQ = 0;
};

//学習を行う
train(rand);

Q値の更新はあっていると思いますが,0〜1に収まるように処理している部分は間違っているかもしれません。ニューラルネットワークの出力値が 0〜1なのでQ値の更新式も合わせた方がいいのではないかと思って付け加えました。

学習

ニューラルネットワークの重みを更新します。
勾配降下法と誤差逆伝播法を自分なりに解釈して実装しています。

    function train(oNum){
        const ETA = 0.1;

        var errorOut;
        var errorHid = new Array(numberH);

        //出力層の誤差
        errorOut = (outputN[oNum] - NewQ)*outputN[oNum]*(1.0 - outputN[oNum]);

        for (var i=0;i<numberH;i++){
            //隠れ層の誤差
            errorHid[i] = hiddenW[i][oNum]*errorOut*hiddenN[i]*(1.0 - hiddenN[i]);

            //重みの更新
            hiddenW[i][oNum] -= ETA*errorOut*hiddenN[i];
        };
        //誤差逆伝播法
        for(var i=0;i<numberI;i++){
            for(var j=0;j<numberH-1;j++){
                inputW[i][j] -= ETA*errorHid[j]*inputN[i];
            };
        };
    };

ニューラルネットワークの出力は3つのQ値を出していますが,実際に使うのは一つなので,誤差errorOutが一つしかありません。この辺りも間違っているかもしれません。

課題

どのように進めたらいいかわからなくなり,行き詰まりを感じています。
うまくいかない原因として
・Q-learningやニューラルネットワークのやり方が間違っている
・googleのようにdeep learningにしなければいけない
・ニューロンの数やパラメータの値が違う
があげられます。

どれをとってももっと勉強しろって感じなんですが,この辺の実装資料がなかなか見つからないので私の学習速度も低下しています。とりあえず地道に実装資料検索+基礎勉強を続けます。

参考資料

DQNの論文
強化学習(Q-learning)
ニューラルネットワーク
ニューラルネットワークを用いたQ-learning

最後に

このプロジェクトは友達と二人で行っています。応援したい!や,一緒にやりたい!って人がいればコメントください。

12
12
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
12
12