0
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

ML-Agents: サンプルソースコードの理解

Posted at

ML-Agents: サンプルソースコードの理解

前回の記事では,Unityで強化学習する準備として,ML-Agentsのサンプルの学習および検証をおこなった.

今回は,上記の記事において示したソースコードの理解を深めるために,どこで何をしているのかについてみていく.C#に慣れていないため,少し慣れていくこともできたらと思う.
これを理解することで,自分のロボットに強化学習を適用させるときにもソースコードをどういじっていけばよいかが分かると期待している.

RollerAgent: ソースコード
RollerAgent
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Actuators;

public class RollerAgent : Agent
{
    Rigidbody rBody;
    void Start () {
        rBody = GetComponent<Rigidbody>();
    }

    public Transform Target;
    public override void OnEpisodeBegin()
    {
       // If the Agent fell, zero its momentum
        if (this.transform.localPosition.y < 0)
        {
            this.rBody.angularVelocity = Vector3.zero;
            this.rBody.velocity = Vector3.zero;
            this.transform.localPosition = new Vector3( 0, 0.5f, 0);
        }

        // Move the target to a new spot
        Target.localPosition = new Vector3(Random.value * 8 - 4,
                                           0.5f,
                                           Random.value * 8 - 4);
    }

		public override void CollectObservations(VectorSensor sensor)
		{
		    // Target and Agent positions
		    sensor.AddObservation(Target.localPosition);
		    sensor.AddObservation(this.transform.localPosition);
		
		    // Agent velocity
		    sensor.AddObservation(rBody.velocity.x);
		    sensor.AddObservation(rBody.velocity.z);
		}

		
		public float forceMultiplier = 10;
		public override void OnActionReceived(ActionBuffers actionBuffers)
		{
		    // Actions, size = 2
		    Vector3 controlSignal = Vector3.zero;
		    controlSignal.x = actionBuffers.ContinuousActions[0];
		    controlSignal.z = actionBuffers.ContinuousActions[1];
		    rBody.AddForce(controlSignal * forceMultiplier);
		
		    // Rewards
		    float distanceToTarget = Vector3.Distance(this.transform.localPosition, Target.localPosition);
		
		    // Reached target
		    if (distanceToTarget < 1.42f)
		    {
		        SetReward(1.0f);
		        EndEpisode();
		    }
		
		    // Fell off platform
		    else if (this.transform.localPosition.y < 0)
		    {
		        EndEpisode();
		    }
		}

}

コードの理解

モジュールのインポート

using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Actuators;

Agentクラスの定義

public class RollerAgent : Agent
{
    Rigidbody rBody;
    public Transform Target;
	public float forceMultiplier = 10;

...
}

Agentクラスを基底としてRollerAgentという名前のクラス

  • Rigidbody型でrBodyという名前で宣言
  • Transform型でTargetという名前で宣言(public)
  • float型でforceMultiplierという名前で宣言(public)

開始関数

void Start () {
    rBody = GetComponent<Rigidbody>();
}

GetComponent()でrBodyにインスタンス生成


エピソードの初めに実行するリセット関数

public override void OnEpisodeBegin()
{
   // If the Agent fell, zero its momentum
    if (this.transform.localPosition.y < 0)
    {
        this.rBody.angularVelocity = Vector3.zero;
        this.rBody.velocity = Vector3.zero;
        this.transform.localPosition = new Vector3( 0, 0.5f, 0);
    }

    // Move the target to a new spot
    Target.localPosition = new Vector3(Random.value * 8 - 4,
                                       0.5f,
                                       Random.value * 8 - 4);
}

this

ここではAgentクラス内であるため,これはAgentを意味する(pythonでいうところのself)

this.transform.localPosition.y

Agentのtransformにおいて自身の位置(y=高さ)

※今回はそれが0よりも小さかったらということで落下を表現

this.rBody.angularVelocity

AgentのrBodyにおいて自身の角速度

this.rBody.Velocity

AgentのrBodyにおいて自身の並進速度

this.transform.localPosition

Agent自身の位置(x, y, z)

Vector3.zero

0の要素を3つもつベクトル

new Vector3(val1, val2, val3)

任意の要素を3つもつベクトル

数値を指定できる(小数使いたければ数字の後ろにfをつける.例:0.5f)

Target.localPosition

Targetに指定しているオブジェクト自身の位置

Random.value

0.0~1.0の乱数を生成(参考


状態関数

public override void CollectObservations(VectorSensor sensor)
{
    // Target and Agent positions
    sensor.AddObservation(Target.localPosition);
    sensor.AddObservation(this.transform.localPosition);

    // Agent velocity
    sensor.AddObservation(rBody.velocity.x);
    sensor.AddObservation(rBody.velocity.z);
}

public override void CollectObservations(VectorSensor sensor)

override修飾子により,基底クラスで定義されているメソッドを上書きして新しいものにすることができる

今回は,VectorSensor型のsensorを引数にもつCollectObservationsメソッドを上書き

sensor.AddObservation()

sensorのAddObervationメソッドにより,observationに加えていく

今回は以下の4つをobservationとして追加

  • Target.localPosition(ターゲットキューブの位置)
  • this.transform.localPosition(エージェントの位置)
  • rBody.velocity.x(エージェントの速度x方向)
  • rBody.velocity.z(エージェントの速度z方向)

アクション関数

public override void OnActionReceived(ActionBuffers actionBuffers)
{
    // Actions, size = 2
    Vector3 controlSignal = Vector3.zero;
    controlSignal.x = actionBuffers.ContinuousActions[0];
    controlSignal.z = actionBuffers.ContinuousActions[1];
    rBody.AddForce(controlSignal * forceMultiplier);

    // Rewards
    float distanceToTarget = Vector3.Distance(this.transform.localPosition, Target.localPosition);

    // Reached target
    if (distanceToTarget < 1.42f)
    {
        SetReward(1.0f);
        EndEpisode();
    }

    // Fell off platform
    else if (this.transform.localPosition.y < 0)
    {
        EndEpisode();
    }
}

public override void OnActionReceived(ActionBuffers actionBuffers)

override修飾子により,基底クラスで定義されているメソッドを上書きして新しいものにすることができる

今回は,ActionBuffers 型のactionBuffersを引数にもつCollectObservationsメソッドを上書き

Vector3 controlSignal

Vector3型でcontrolSignalをVector3.zeroで初期化

今回はcontrolSignalでエージェントに与える力の入力を格納

actionBuffers.ContinuousActions[0]

actionBuffersがもつ連続値アクションの0番目

今回は,これをエージェントに与える力のx方向の割合としている

actionBuffers.ContinuousActions[1]

actionBuffersがもつ連続値アクションの1番目

今回は,これをエージェントに与える力のz方向の割合としている

rBody.AddForce(controlSignal * forceMultiplier)

エージェントに力(3軸方向から)を加える

今回は,forceMultiplierをゲインとしている

Vector3.Distance(this.transform.localPosition, Target.localPosition);

Vector3.Distanceで指定した2つの位置から距離を算出

SetReward(1.0f);

報酬を与える

EndEpisode();

エピソードを終了する

今回は,ぶつかるか,落ちるかの2パターンで終了を設けている

感想

ある程度,どこで何をしているかは確認できた.
強化学習自体は,Pythonではあるが,OpenAI Gymで自作の環境を作成したことがあるため,
ObservationやAction,Resetなどは分かる.そのため,どこがそれらにあたるものかを理解できたことは非常に大きい.C#自体は使ったことはないが,雰囲気はCやC++であるため,抵抗はなかった.実際に自作の環境を作る際には,どうやってほしい値を取得するかなどの知識が必要にはなるが,それは状況に合わせてその都度調べて使えたらと思う.

参考文献

0
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?