Qiita Teams that are logged in
You are not logged in to any team

Log in to Qiita Team
Community
OrganizationEventAdvent CalendarQiitadon (β)
Service
Qiita JobsQiita ZineQiita Blog
40
Help us understand the problem. What are the problem?

More than 3 years have passed since last update.

【Unity ML-Agents】強化学習で物体を避ける

先日新バージョンがリリースされたUnity ML-Agents v0.3を使っていろいろ遊んでみました。
ML-Agents v0.3 Beta released: Imitation Learning, feedback-driven features, and more

新機能の紹介や変更点などは前回の記事をご覧ください。

SnapCrab_NoName_2018-3-25_15-13-46_No-00.png

今回は上から降ってくる物体を避けるというゲームを作ってみました。シューティングゲームみたいなもんです。

新バージョンの注意点

まずは新しくなったUnity ML-Agentsの使い方などを少し紹介していきたいと思います。
Adademy, Brain, Agent という基本構造などは変わりませんが、APIに関していろいろ変更されています。

TemplateAgent.cs
public class TemplateAgent : Agent {



    public override void CollectObservations()
    {

    }

    public override void AgentAction(float[] vectorAction, string textAction)
    {

    }

    public override void AgentReset()
    {

    }

    public override void AgentOnDone()
    {

    }
}

新バージョンではCollectObservations()に状態を書いていきます。
例えばサンプルにあるPushBlockを例にしてみます。

SnapCrab_NoName_2018-3-25_15-24-54_No-00.png

PushAgentBasic.cs
public override void CollectObservations()
    {
        float rayDistance = 12f;
        float[] rayAngles = { 0f, 45f, 90f, 135f, 180f, 110f, 70f };
        string[] detectableObjects;
        detectableObjects = new string[] { "block", "goal", "wall" };
        AddVectorObs(rayPer.Perceive(
            rayDistance, rayAngles, detectableObjects, 0f, 0f));
        AddVectorObs(rayPer.Perceive(
            rayDistance, rayAngles, detectableObjects, 1.5f, 0f));
    }

AddVectorObs()を使って状態を取得していきます。
また、あらたにRayPerceptionというスクリプトも追加され、これをAgentに張り付けて使います。

SnapCrab_NoName_2018-3-25_15-28-20_No-00.png

これはrayDistanceの距離,rayAnglesの方向に向けて、Physics.SphereCast()を使うことで、半径0.5の球体を飛ばして衝突判定を行いエージェントからオブジェクトまでの距離を取得しています。

RayPerception.gif

PushBlockの場合であれば、高さ0fと1.5fから7方向ずつ、計14本のレーザーで物体を検出します。

RayPerceptionを使った場合の状態数の計算式は以下の通りです。

(rayAngles) * (detectableObjects+2) * (rayPer.Perceive())

PushBlockでは 7*(3+2)*2で状態数は70になります。

SnapCrab_NoName_2018-3-25_15-42-50_No-00.png

しかし、なぜdetectableObjectsに+2するのかは謎。RayPerceptionスクリプト内にその記述があります。

※コメント欄より
どうやら、detectableObjectsとしてTagを渡したオブジェクトに当たったかどうかの当たり判定、そのオブジェクトまでの距離、の2点を情報として追加しているため+2しているようです。

シューティングゲームを作る

シューティングゲーム自体の実装はたくさん記事があると思うのでここでは割愛させていただきます。

SnapCrab_NoName_2018-3-25_15-52-38_No-00.png

私は上の画像のような感じで作りました。
上からボールが降ってきてぶつかるとゲームオーバーです。

SG_Play.gif

ShootingGameAgent.cs
public override void CollectObservations()
    {
        float rayDistance = 10f;
        float[] rayAngles = {0f,15f,30f,45f,60f,75f,
             90f,105f,120f,135f,150f,165f,
            180f,195f,210f,225f,240f,255f,
            270f,285f,300f,315f,330f,345f};
        string[] detectableObjects;
        detectableObjects = new string[] { "Enemy" };
        AddVectorObs(rayPer.Perceive(
            rayDistance, rayAngles, detectableObjects, 0f, 0f));
    }

状態はこんな感じで取得しています。状態数は72。
ちなみにdetectableObjectsはタグを利用します。ボールにEnemyというタグをつけています。

SnapCrab_NoName_2018-3-25_16-5-8_No-00.png

ゲームを実行して、右上にあるGizmosというボタンを押せば、レーザーがどのように照射されているか確認できます。

プレイヤーとボールが衝突したときは以下のような記述を書きました。

ShootingGameAgent.cs
void OnTriggerEnter(Collider col)
    {
        if (AirPlane.activeSelf == true)
        {
            if (col.gameObject.tag == "Enemy")
            {
                Instantiate(explosion, transform.position, Quaternion.identity);
                AddReward(-10f);
                AirPlane.SetActive(false);
                Done();
            }
        }
    }

プレイヤーがボールに衝突したとき、プレイヤーのアクティブをオフにし、爆発エフェクトを発生させます。
また、AddReward()を使うことで報酬を追加します。ボールにぶつかったら-10。
Done()は旧バージョンでのdone=trueです。これでAgentReset()を呼び出します。

ShootingGameAgent.cs
public override void AgentAction(float[] vectorAction, string textAction)
    {
        MoveAgent(vectorAction);

        AddReward(0.01f);
    }

ボールにぶつかっていない状態のときは、毎アクションごとに0.01の報酬を追加します。
MoveAgent()にはプレイヤーのアクションを書きましょう。ここは旧バージョンと変わりません。
私はこんな感じで書きました。

ShootingGameAgent.cs
public float speed;

public void MoveAgent(float[] act)
    {
        int action = Mathf.FloorToInt(act[0]);

        //アクション
        if (action==1)
        {
            transform.Translate(0, 0, speed);
        }
        if (action==2)
        {
            transform.Translate(-speed, 0, 0);
        }
        if (action==3)
        {
            transform.Translate(0, 0, -speed);
        }
        if (action==4)
        {
            transform.Translate(speed, 0, 0);
        }
        //移動制限
        if (transform.localPosition.x < -4.5f)
        {
            transform.localPosition =
                new Vector3(-4.5f, 0.5f, transform.localPosition.z);
        }
        if (4.5f < transform.localPosition.x)
        {
            transform.localPosition =
                new Vector3(4.5f, 0.5f, transform.localPosition.z);
        }
        if (transform.localPosition.z < -6f)
        {
            transform.localPosition =
                new Vector3(transform.localPosition.x, 0.5f, -6f);
        }
        if (6f < transform.localPosition.z)
        {
            transform.localPosition =
                new Vector3(transform.localPosition.x, 0.5f, 6f);
        }
    }

BrainTypeをPlayerにしてActionSizeを5、Elementにキーを割り当てます。

SnapCrab_NoName_2018-3-25_16-20-41_No-00.png

これでWASDでプレイヤーを操作できます。

あと、AcademyのTimeScaleの部分なんですが、ここで少しハマりました。

SnapCrab_NoName_2018-3-25_16-22-37_No-00.png

このタイムスケールは何倍の速度で強化学習をするのかを指定できます。数字を上げていけば10倍、100倍のスピードでゲームを進行させていくので単純に時間の節約になります。
ただ、FixdUpdate()のタイミングで時間を早めていますのでUpdate()に処理などを書いてタイムスケールを上げていくと挙動がおかしくなるので注意です。
タイムスケールが1なら問題ありませんが、もっと上げていく場合はUpdate()ではなくFixeUpdate()に処理を書きましょう。

あとはビルドして学習をしていきましょう。

SnapCrab_NoName_2018-3-25_16-29-25_No-00.png

ppo.pyはlearn.pyに変更され、ハイパーパラメータなどはtrainer_configファイルで指定するようになりました。

オプションなども少し変更されているので注意。

python learn.py ShootingGame --train --run-id=SG

結果

学習結果

5万回

SG_Internal_50000.gif

5万回でもすでにそこそこ避けています。

35万回

SG_Internal_350000.gif

35万回になるとほぼぶつかることはなくなりました。しかも、5万回のときに比べて、かなりギリギリ(小さい動き)でボールを避けているのがわかります。

バグ?

今回の実験ではAIは基本的に左下に陣取ってボールを避けています。しかし、なぜか一番右端のボールは見えていないのか避けることができません。100万回ぐらい学習させても、上の画像のようにほとんどは避けられるのですが、一番右端だけは避けられないのです。

何度か実験してみたところ、今回とは対照的に右下に陣取るときもありました。そのときもほとんどは避けられるのですが、今度は一番左端のボールが避けられなくなります。

Tensorflow内でどのように計算されているのか私にはわからないので原因は不明です。

ダウンロード

GitHubから今回作った環境をダウンロードできます。
My ML-Agents Game

5万回~35万回の学習済みデータも付属しています。

Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
40
Help us understand the problem. What are the problem?