LoginSignup
0
0

More than 1 year has passed since last update.

AIテニスプレイヤーとVRで試合しよう(MlAgents)

Last updated at Posted at 2023-03-05

目的 
MlAgentsを使用してAIにテニスを覚えさせる。自身の手をトラッキングしVR空間内で対戦する

使用したもの
MlAgents, Pico4, unity2020.3.2f1, steamVRPlugin

MlAgentsの仕組み

テニスを学習させる過程は 観測→重み付け→行動決定→報酬
です。

公式が出しているサンプルシーンのうちテニスについてコードの解説をします。

強化学習を行うゲームオブジェクトにはMonoBehaviourクラスを継承してクラスを作成するのではなくAgentクラスを継承しoverrideして作成する。
(public class TennisAgent : Agent)

Agentクラスのメソッド紹介

public override void Initialize() :エージェント初期化時
public override void OnEpisodeBegin() :エピソード開始時
public override void CollectObservations() :エージェントが観察するデータの収集時
public override void OnActionReceived(ActionBuffers actionBuffers) :観察によって収集した情報を決定した後にアクションを起こす時
public override void Heuristic(in ActionBuffers actionsOut) :学習状態でない時にテストプレイヤーがキャラクターを操作したりを切り替える時に使用

1エピソードとは?
→ MaxStepで指定したステップ回数に達した(目的達成できなかった)またはEndEpisode関数を呼び出した(目的達成できた)時に学習が終了する。その区切りのこと

このプログラムの流れ
Initialize → OnEpisodeBegin → CollectObservations → OnActionReceived

Header

このコンポーネントはAgentA, AgentBにアタッチされています。

.cs
    [Header("Specific to Tennis")]
    public GameObject ball;
    public bool invertX;
    public int score;
    public GameObject myArea;
    public float angle;
    public float scale;

    Text m_TextComponent;
    Rigidbody m_AgentRb;
    Rigidbody m_BallRb;
    float m_InvertMult;
    EnvironmentParameters m_ResetParams;

    // Looks for the scoreboard based on the name of the gameObjects.
    // Do not modify the names of the Score GameObjects
    const string k_CanvasName = "Canvas";
    const string k_ScoreBoardAName = "ScoreA";
    const string k_ScoreBoardBName = "ScoreB";

Initialize

Initialize()はエージェント初期化時に呼ばれます
m_AgentRbに自身(AgentA or B)のRigidbodyを入れる
invertX でAgentA or Bを判別
m_BallRbにballのRigidbodyを入れる
Academy.Instance.EnvironmentParametersとは?
概要
前半は必要な変数とGameObjectの対応付け
後半は初期化 ラケットのz軸を55°にしボールのサイズを0.5fにしている
ラケットのリスポーンはOnEpisodeBegin()
ボールのリスポーンは別のクラスで行う

.cs
    public override void Initialize()
    {
        m_AgentRb = GetComponent<Rigidbody>();
        m_BallRb = ball.GetComponent<Rigidbody>();
        var canvas = GameObject.Find(k_CanvasName);
        GameObject scoreBoard;
        m_ResetParams = Academy.Instance.EnvironmentParameters;
        if (invertX)
        {
            scoreBoard = canvas.transform.Find(k_ScoreBoardBName).gameObject;
        }
        else
        {
            scoreBoard = canvas.transform.Find(k_ScoreBoardAName).gameObject;
        }
        m_TextComponent = scoreBoard.GetComponent<Text>();
        SetResetParameters();
    }

    public void SetResetParameters()
    {
        angle = m_ResetParams.GetWithDefault("angle", 55);
        gameObject.transform.eulerAngles = new Vector3(
            gameObject.transform.eulerAngles.x,
            gameObject.transform.eulerAngles.y,
            m_InvertMult * angle
        );

        scale = m_ResetParams.GetWithDefault("scale", .5f);
        ball.transform.localScale = new Vector3(scale, scale, scale);
    }

OnEpisodeBegin

AgentA ,Bを判別と初期位置の決定及び玉の初期値を0に設定
SetResetParameters()実行

.cs
    public override void OnEpisodeBegin()
    {
        //AgentA ,Bを判別し生成する位置を決める
        m_InvertMult = invertX ? -1f : 1f;
        //親の位置を中心とし上下方向にランダムに配置
        transform.position = new Vector3(-m_InvertMult * Random.Range(6f, 8f), -1.5f, -1.8f) + transform.parent.transform.position;
        m_AgentRb.velocity = new Vector3(0f, 0f, 0f);

        SetResetParameters();
    }

CollectObservations

学習に使用する観測値の収集
パラメータは9つ
位置,速度(2) × x,y (2) × ラケット,ボール(2)
+ラケットの傾き = 9
ラケットの位置、速度とボールの距離を与えることで打ち返す相関を学習できる
学習に使用する観測値を正規化しなくてよいのか?

.cs
    public override void CollectObservations(VectorSensor sensor)
    {
        //エリアの中心からラケットまでの距離(縦方向)
        sensor.AddObservation(m_InvertMult * (transform.position.x - myArea.transform.position.x));
        //地面からラケットの高さ
        sensor.AddObservation(transform.position.y - myArea.transform.position.y);
        //ラケットの縦方向の速さ
        sensor.AddObservation(m_InvertMult * m_AgentRb.velocity.x);
        //ラケットの横方向の速さ
        sensor.AddObservation(m_AgentRb.velocity.y);
        /* 以下ボールについて同様*/
        sensor.AddObservation(m_InvertMult * (ball.transform.position.x - myArea.transform.position.x));
        sensor.AddObservation(ball.transform.position.y - myArea.transform.position.y);
        sensor.AddObservation(m_InvertMult * m_BallRb.velocity.x);
        sensor.AddObservation(m_BallRb.velocity.y);
        //ラケットの傾き
        sensor.AddObservation(m_InvertMult * gameObject.transform.rotation.z);
    }

OnActionReceived

学習結果から行動に移す
行動内容
1
ラケットが地面から6以上上にあり一定の速度以上なら地面方向に移動させる
2
位置計測から学習した値を元に前進,後進させる
3
位置計測から学習した値を元にラケットのZ軸を回転させる

注意
報酬は別のクラスで設定している

キャンバスのスコア更新もここで行う スコアの加算は別のクラスで変更する

.cs
    public override void OnActionReceived(ActionBuffers actionBuffers)

    {
        //学習結果を格納する continuousActions[2](BehaviorParametersで設定)
        var continuousActions = actionBuffers.ContinuousActions;
        //Clampにより-1 ~ 1 の間に値を収める
        var moveX = Mathf.Clamp(continuousActions[0], -1f, 1f) * m_InvertMult;
        var moveY = Mathf.Clamp(continuousActions[1], -1f, 1f);
        var rotate = Mathf.Clamp(continuousActions[2], -1f, 1f) * m_InvertMult;
        //ラケットの高さが地面から6以上 かつ Y方向の速さが一定以上ならば (地面のY座標は-7.5)
        if (moveY > 0.5 && transform.position.y - transform.parent.transform.position.y < -1.5f)
        {
            //地面方向に移動させる
            m_AgentRb.velocity = new Vector3(m_AgentRb.velocity.x, 7f, 0f);
        }
        //学習した値を元にX方向に前進
        m_AgentRb.velocity = new Vector3(moveX * 30f, m_AgentRb.velocity.y, 0f);
        //学習した値を元にラケットのZ軸を回転させる
        m_AgentRb.transform.rotation = Quaternion.Euler(0f, -180f, 55f * rotate + m_InvertMult * 90f);
        //ラケットAとエリア中央までの距離が-1以下なら → 相手エリアに1以上入ったならリスポーンさせる
        if (invertX && transform.position.x - transform.parent.transform.position.x < -m_InvertMult ||
            !invertX && transform.position.x - transform.parent.transform.position.x > -m_InvertMult)
        {
            //リスポーン位置
            transform.position = new Vector3(-m_InvertMult + transform.parent.transform.position.x,
                transform.position.y,
                transform.position.z);
        }
        //スコア更新
        m_TextComponent.text = score.ToString();
    }

報酬の設定

ラケットの位置計測 ラケットの行動決定はラケットにTennisAgentをアタッチし実装したが、報酬設定はball内スクリプトで設定する(壁,ラケットに当たることが報酬に関わるため)
長くなるので次の記事で解説します

閲覧いただきありがとうございました。
他にもVRChatでのステージ製作やPun2を使用したオンラインボードゲーム製作を行っていますので興味のある方ぜひチェックしてください。
以上で失礼します。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0