ML-Agents: サンプルソースコードの理解
前回の記事では,Unityで強化学習する準備として,ML-Agentsのサンプルの学習および検証をおこなった.
今回は,上記の記事において示したソースコードの理解を深めるために,どこで何をしているのかについてみていく.C#に慣れていないため,少し慣れていくこともできたらと思う.
これを理解することで,自分のロボットに強化学習を適用させるときにもソースコードをどういじっていけばよいかが分かると期待している.
RollerAgent: ソースコード
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Actuators;
public class RollerAgent : Agent
{
Rigidbody rBody;
void Start () {
rBody = GetComponent<Rigidbody>();
}
public Transform Target;
public override void OnEpisodeBegin()
{
// If the Agent fell, zero its momentum
if (this.transform.localPosition.y < 0)
{
this.rBody.angularVelocity = Vector3.zero;
this.rBody.velocity = Vector3.zero;
this.transform.localPosition = new Vector3( 0, 0.5f, 0);
}
// Move the target to a new spot
Target.localPosition = new Vector3(Random.value * 8 - 4,
0.5f,
Random.value * 8 - 4);
}
public override void CollectObservations(VectorSensor sensor)
{
// Target and Agent positions
sensor.AddObservation(Target.localPosition);
sensor.AddObservation(this.transform.localPosition);
// Agent velocity
sensor.AddObservation(rBody.velocity.x);
sensor.AddObservation(rBody.velocity.z);
}
public float forceMultiplier = 10;
public override void OnActionReceived(ActionBuffers actionBuffers)
{
// Actions, size = 2
Vector3 controlSignal = Vector3.zero;
controlSignal.x = actionBuffers.ContinuousActions[0];
controlSignal.z = actionBuffers.ContinuousActions[1];
rBody.AddForce(controlSignal * forceMultiplier);
// Rewards
float distanceToTarget = Vector3.Distance(this.transform.localPosition, Target.localPosition);
// Reached target
if (distanceToTarget < 1.42f)
{
SetReward(1.0f);
EndEpisode();
}
// Fell off platform
else if (this.transform.localPosition.y < 0)
{
EndEpisode();
}
}
}
コードの理解
モジュールのインポート
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Actuators;
Agentクラスの定義
public class RollerAgent : Agent
{
Rigidbody rBody;
public Transform Target;
public float forceMultiplier = 10;
...
}
Agentクラスを基底としてRollerAgentという名前のクラス
- Rigidbody型でrBodyという名前で宣言
- Transform型でTargetという名前で宣言(public)
- float型でforceMultiplierという名前で宣言(public)
開始関数
void Start () {
rBody = GetComponent<Rigidbody>();
}
GetComponent()でrBodyにインスタンス生成
エピソードの初めに実行するリセット関数
public override void OnEpisodeBegin()
{
// If the Agent fell, zero its momentum
if (this.transform.localPosition.y < 0)
{
this.rBody.angularVelocity = Vector3.zero;
this.rBody.velocity = Vector3.zero;
this.transform.localPosition = new Vector3( 0, 0.5f, 0);
}
// Move the target to a new spot
Target.localPosition = new Vector3(Random.value * 8 - 4,
0.5f,
Random.value * 8 - 4);
}
this
ここではAgentクラス内であるため,これはAgentを意味する(pythonでいうところのself)
this.transform.localPosition.y
Agentのtransformにおいて自身の位置(y=高さ)
※今回はそれが0よりも小さかったらということで落下を表現
this.rBody.angularVelocity
AgentのrBodyにおいて自身の角速度
this.rBody.Velocity
AgentのrBodyにおいて自身の並進速度
this.transform.localPosition
Agent自身の位置(x, y, z)
Vector3.zero
0の要素を3つもつベクトル
new Vector3(val1, val2, val3)
任意の要素を3つもつベクトル
数値を指定できる(小数使いたければ数字の後ろにfをつける.例:0.5f)
Target.localPosition
Targetに指定しているオブジェクト自身の位置
Random.value
0.0~1.0の乱数を生成(参考)
状態関数
public override void CollectObservations(VectorSensor sensor)
{
// Target and Agent positions
sensor.AddObservation(Target.localPosition);
sensor.AddObservation(this.transform.localPosition);
// Agent velocity
sensor.AddObservation(rBody.velocity.x);
sensor.AddObservation(rBody.velocity.z);
}
public override void CollectObservations(VectorSensor sensor)
override修飾子により,基底クラスで定義されているメソッドを上書きして新しいものにすることができる
今回は,VectorSensor型のsensorを引数にもつCollectObservationsメソッドを上書き
sensor.AddObservation()
sensorのAddObervationメソッドにより,observationに加えていく
今回は以下の4つをobservationとして追加
- Target.localPosition(ターゲットキューブの位置)
- this.transform.localPosition(エージェントの位置)
- rBody.velocity.x(エージェントの速度x方向)
- rBody.velocity.z(エージェントの速度z方向)
アクション関数
public override void OnActionReceived(ActionBuffers actionBuffers)
{
// Actions, size = 2
Vector3 controlSignal = Vector3.zero;
controlSignal.x = actionBuffers.ContinuousActions[0];
controlSignal.z = actionBuffers.ContinuousActions[1];
rBody.AddForce(controlSignal * forceMultiplier);
// Rewards
float distanceToTarget = Vector3.Distance(this.transform.localPosition, Target.localPosition);
// Reached target
if (distanceToTarget < 1.42f)
{
SetReward(1.0f);
EndEpisode();
}
// Fell off platform
else if (this.transform.localPosition.y < 0)
{
EndEpisode();
}
}
public override void OnActionReceived(ActionBuffers actionBuffers)
override修飾子により,基底クラスで定義されているメソッドを上書きして新しいものにすることができる
今回は,ActionBuffers 型のactionBuffersを引数にもつCollectObservationsメソッドを上書き
Vector3 controlSignal
Vector3型でcontrolSignalをVector3.zeroで初期化
今回はcontrolSignalでエージェントに与える力の入力を格納
actionBuffers.ContinuousActions[0]
actionBuffersがもつ連続値アクションの0番目
今回は,これをエージェントに与える力のx方向の割合としている
actionBuffers.ContinuousActions[1]
actionBuffersがもつ連続値アクションの1番目
今回は,これをエージェントに与える力のz方向の割合としている
rBody.AddForce(controlSignal * forceMultiplier)
エージェントに力(3軸方向から)を加える
今回は,forceMultiplierをゲインとしている
Vector3.Distance(this.transform.localPosition, Target.localPosition);
Vector3.Distanceで指定した2つの位置から距離を算出
SetReward(1.0f);
報酬を与える
EndEpisode();
エピソードを終了する
今回は,ぶつかるか,落ちるかの2パターンで終了を設けている
感想
ある程度,どこで何をしているかは確認できた.
強化学習自体は,Pythonではあるが,OpenAI Gymで自作の環境を作成したことがあるため,
ObservationやAction,Resetなどは分かる.そのため,どこがそれらにあたるものかを理解できたことは非常に大きい.C#自体は使ったことはないが,雰囲気はCやC++であるため,抵抗はなかった.実際に自作の環境を作る際には,どうやってほしい値を取得するかなどの知識が必要にはなるが,それは状況に合わせてその都度調べて使えたらと思う.
参考文献