1
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

RLlibで強化学習したモデルをC#で使う

Posted at

目的

Ray RLlibで強化学習したモデルをC#で動かせるようにします。

背景

AI・機械学習は python でやるのがライブラリ等、環境が整備されていて一番楽だと思います。
一方でゲーム開発はUnityをよく使うと思いますが、こちらはC#でスクリプトを書きます。

自作ゲームにつよつよAIを導入したくないですか?したいですよね。
僕がまさにその状況にあるので、pythonで学習したモデルをC#に移植する方法を調べてみました。

注意
ここでは僕の用途の関係上、RLlibという強化学習ライブラリで学習させたAIを、C#で使う方法について書きます。
学習済みモデルをC#で使うだけならMicrosoftのこの記事を読んだ方が早いです。
本記事もこのMicrosoftの記事をめちゃくちゃ参考にして書いています。

具体的な目標

PPOエージェントを用いてSimpleCorriderという環境に対して強化学習を実施して、出来上がった学習済みモデルをC#のML.Netライブラリを使って動かします。

作業の概要

pythonでやること: RLlibで強化学習 → onnx形式でエクスポート
C#でやること: onnx形式のモデルをインポート → 学習した通りの結果をPredictできることの確認

【Python】RLlibで強化学習

RLlibのGetting Started Guideにある例で強化学習モデルを作成します。

必要なパッケージ

以下のパッケージを使います。

  • Ray: 強化学習
  • Tensorflow: ニューラルネットワーク作成・学習
  • tf2onnx: onnx形式への変換
pip install "ray[rllib]" tensorflow tf2onnx

SimpleCorrider環境

これから強化学習するSimpleCorrider環境を説明します。
SimpleCorriderは動作確認用の環境で、面白いものではないです。
直線状にスタート地点とゴール地点があり、スタート地点から開始してプラスかマイナスのどちらかに進むという2つのアクションをとります。
行動のたびにペナルティとしてマイナスの報酬が課され、ゴール地点までたどり着ければプラスの報酬がもらえます。

  • オブザベーション[int]: 現在の位置
  • アクション[int]: 0 or 1 (0のとき-1, 1のとき+1進む)
  • 報酬[float]: ゴールしたら+1.0, それ以外は -0.1

うまく学習できれば、ポリシーは1を連打するわけです。

学習したモデルのエクスポート

Ray Getting Started GuideのSimpleCorriderを強化学習している使用例のコードをコピペしてください。Ray ML Quick StartのRLlibのドロップダウンリストになっている箇所です。

このコードはPPOTrainerで学習し、最後に学習済みモデルを使って1エピソードを走らせて累積報酬を求めるという内容になっています。この学習の後(trainer.train()のあるfor文を抜けた後)に以下のように追記します。

追記する
outdir = "exported_onnx" # 適当なディレクトリ
trainer.export_policy_model(outdir, onnx=11)

追記したスクリプトを実行すれば、
outdirディレクトリにsaved_model.onnxというファイル名で学習済みモデルがonnx形式でエクスポートされます。
ただし次のセクションの内容もスクリプトに追加してからの実行をおすすめします。

対応するグラフの変数名を調べる

この後C#にモデルをインポートするわけですが、ここでいうモデルというのは、ニューラルネットワークのグラフのことです。
グラフにC#から入出力するには、グラフ上の変数名を指定してやる必要があります。
なので変数名を調べておきましょう。

policy = trainer.get_policy()

print(policy._obs_input)
# >> Tensor("default_policy/obs:0", shape=(?, 1), dtype=float32)
print(policy._sampled_action)
# >> Tensor("default_policy/cond_1/Merge:0", shape=(?,), dtype=int64)

PPOエージェントの場合はpolicy._obs_inputにあるdefault_policy/obs:0がオブザベーションに対応するグラフの入力名、policy._sampled_actionにあるdefault_policy/cond_1/Merge:0がアクションに対応するグラフの出力名になります。

注意
今回はアクションが離散値なので、default_policy/cond_1/Merge:0が直接アクションに対応しますが、例えば境界のある連続値(gym envでいうBox)であればスケーリング処理はグラフの外側で行われる仕様のため、C#上でグラフの出力を期待したアクションの形式になるように変換する必要があります。

【C#】学習済みonnxモデルの実行

先ほどエクスポートしたsaved_model.onnxをC#で動かします。

必要なパッケージ

以下のパッケージを使います。

  • Microsoft.ML
  • Microsoft.ML.OnnxRuntime
  • Microsoft.ML.OnnxTransformer

学習モデルのインポート

ML.Netでモデルを予測させるのは結構面倒です。
Load() → Predict()くらいシンプルにならんもんかね。。。

入出力の定義

まずグラフの入力と出力を定義します

Program.cs
public class OnnxInput
{
    [ColumnName("default_policy/obs:0")]
    public float CurPos { get; set; }
}

public class OnnxOutput
{
    [ColumnName("default_policy/cond_1/Merge:0"), OnnxMapType(typeof(Int64), typeof(Single))]
    public Int64[] Action { get; set; }
}

Attributeにグラフとの対応を指定します。変数名や型については「対応するグラフの変数名を調べる」のセクションを参照ください。
float型は型変換の必要はないですが、それ以外の場合はOnnxMapTypeで変換しないとエラーになります。

インポート

Predict()メソッドを持っているのはPredictEngineというオブジェクトです。
このオブジェクトを作るまでが長いです。おまじないだと思って我慢しましょう。
以下のようにします。

Program.csのProgramクラス
static PredictionEngine<OnnxInput, OnnxOutput> CreatePredictEngine()
{
    string modelPath = "..../saved_model.onnx"; // Python編でエクスポートしたonnxファイルを指定

    var mlContext = new MLContext();
    var tensorFlowModel
        = mlContext.Transforms.ApplyOnnxModel(modelPath);
    var emptyDv = mlContext.Data.LoadFromEnumerable(new OnnxInput[] { });
    var trnsf = tensorFlowModel.Fit(emptyDv);
    var engine = mlContext.Model.CreatePredictionEngine<OnnxInput, OnnxOutput>(trnsf);

    return engine;
}

愚痴
var trnsf = tensorFlowModel.Fit(emptyDv); この行やばすぎる。
trnsfはOnnxTranformerという名前のクラスのインスタンスなんだけど、NNでTransformerといったらAll you needしか思いつかん。それにFit()はいまにも学習しそうな名前。ややこしいのやめて><

Predict

ここまで来れば予測は簡単です。
さっき作成した入出力クラスで型付けされるのでpythonよりも整然でエレガント。

var predictEngine = CreatePredictEngine()
var onnxInput = new OnnxInput { CurPos = obs };
var onnxOutput = predictEngine.Predict(onnxInput);

動作確認

ちゃんとonnxモデルを読み込めていて、動作するかを確認します。
せっかくなのでSimpleCorrider EnvのC#版を作ってみます。

SimpleCorriderクラス(C#版)

Program.cs
public class SimpleCorridor
{
    public int EndPos { get; set; }
    public int CurPos { get; set; }

    public SimpleCorridor(int corridorLength)
    {
        this.EndPos = corridorLength;
        this.CurPos = 0;
    }

    public int Reset()
    {
        this.CurPos = 0;
        return this.CurPos;
    }

    public (int, double, bool) Step(int action)
    {
        this.CurPos += ((action == 0) && (this.CurPos > 0)) ? -1 : 1;
        var done = (this.CurPos >= this.EndPos);
        var reward = done ? 1.0 : -0.1;
        return (this.CurPos, reward, done);
    }
}

動作確認

インポートしたモデルで1エピソード実行してみます。

Program.cs
static void Main(string[] args)
{
    var env = new SimpleCorridor(10);
    var predictEngine = CreatePredictEngine();

    var obs = env.Reset();
    var done = false;
    double reward;
    var totalReward = 0.0;
    while (!done)
    {
        var onnxInput = new OnnxInput { CurPos = obs };
        var onnxOutput = predictEngine.Predict(onnxInput); // 予測
        var action = (int)onnxOutput.Action[0];  // intに変換
        (obs, reward, done) = env.Step(action);
        totalReward += reward;
    }

    Console.WriteLine("Total Reward = {0}", totalReward);
}

これで Total Reward がpythonのときと同じくらいになれば成功です。
おつかれさまでした。

参考文献

1
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?