6
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

UnityでPytorchライクの機械学習ライブラリを作る。 1日目:Tensorクラスの実装その1

Last updated at Posted at 2021-02-14

1. はじめに

 私は物理エンジンを使って強化学習をすることに興味があって、Unityを触っていたのですが、Unityでの強化学習はml-agentが主で自分でも使ってみたのですが、一つ納得行かない点が出てきました。
ランタイムでモデルを新しく作れない
 そもそもml-agentではモデル(学習モデルも物理モデルも)があらかじめ決まっているものに対して学習を行いビルド後は推論することを目的としているので通常は問題ないのですが自分の用途ではビルド後も学習できるか、最低でもランタイムでモデル定義ができるような物が望ましいです。

 そこでUnity用に機械学習のライブラリ(実際はパッケージ)をとして作ることにしました。
 ライブラリを作るにあたって達成したい条件は以下の4つです。

  • PytorchのようにDefine-by-Runの機械学習ライブラリ
  • ネットワークのモデルはLinearなどの単純なものだけで良しとする(強化学習で使うのでConvolutionなどの層は無しで)
  • ネットワークの最適化手法は性能が良さそうなものを一つか二つ実装。
  • MinやMeanなどの関数についても強化学習で用いるものをかいつまんで実装する。

 欲を言えば最終的な目標は

 全ての強化学習モデルの実装

 ですが、ひとまずのゴールとしては

 DQNを実装すること

 これを見据えて実装していこうと思います。

1.2. 環境

MacBook Air (Retina, 13-inch, 2018)
プロセッサ 1.6 GHz デュアルコアIntel Core i5
メモリ 8 GB 2133 MHz LPDDR3
Unity 2019.4.18f1

1.3. 準備

 ライブラリを作るための準備をします。こちらを参考にしました。
 【Unity:PackageManager】自作Packageの作成と追加

1.4 フォルダ構成

Rein
┣ package.json
┗ Scripts
  ┣ Editor
  ┣ Runtime
    ┣ Utils
      ┗ Exceptions.cs
    ┣ Functions
      ┗ IFunction.cs
    ┗ Tensor.cs
  ┗ Test

大体ランタイムに書いていきます。

2. Tensorクラスの実装

 ここで実装するTensorは機械学習で使われる計算グラフの実装のための重要な材料の一つです。

2.1. 計算グラフの考え方

 実装の前に機械学習のライブラリで当たり前に使われている計算グラフについての説明を行います。自分なりに噛み砕いたのですが少し分かり辛いかもしれません。
 例えば以下のようなグラフを考えます。
ComputeGraph.jpg

これは式で表すと以下のようになります。

$$
z=(x+y)+(x*y)
$$

この時

$$
\frac{\partial z}{\partial x}=\frac{\partial(x+y)}{\partial x}+\frac{\partial(x*y)}{\partial x}
$$

となるが、計算グラフにおいてはこの処理を分割して行います。この例で言えば足し算と割り算を一つの関数として、

Add(x,\, y) = x + y\\
Mul(x,\, y) = x * y\\
z = Add(Add(x,\, y),\, Mul(x,\, y))

として、Forward処理の時には通常の計算を行い、Backwardの時には出力の勾配と関数自体の勾配をチェインルールを使用して勾配を伝播させる。この例で言えば

z_1=Add(x,\, y)\\
z_2=Mul(x,\, y)\\
z=Add(z_1,\, z_2)

として

\frac{\partial z}{\partial x}=\frac{\partial z}{\partial z_1}\frac{\partial z_1}{\partial x}+\frac{\partial z}{\partial z_2}\frac{\partial z_2}{\partial x}=\frac{\partial Add(z_1, z_2)}{\partial z_1}\frac{\partial Add(x, y)}{\partial x}+\frac{\partial Add(z_1, z_2)}{\partial z_1}\frac{\partial Mul(x, y)}{\partial x}\\
\frac{\partial z}{\partial x}=1\times1+1\times y=1+y

$z$のBackwardを呼び出した際には、

$z$を出力した$Add$の微分関数を呼び出し、この時の$Add$の入力$z_1,z_2$を代入し、$z_1,z_2$のBackwardを呼び出す。
そして$z_1$を出力した$Add$の...と再帰的に処理していくことで$z$に対する全てのパラメータの微分を得ることができる。
このように分割して計算するように実装すれば計算グラフを構成する関数は通常のForward用の関数とBackward用の微分した関数を用意しておき、入力を格納しておけば全ての入力に対する微分値が得られるようになる。

2.2 実装

 実際に上の理論に従って実装していきます。しかし機械学習では入力は多次元の配列として使用することが多いので、変数用のクラスとしてPytorchのようにTensorを作っていきたいのですが、ここで問題があります。
 C#では多次元の配列自体は存在するのですが配列の形(Pytorchではshape)が可変のものは存在しないのです。これを解決するためにデータを一次元の配列として持っておき、それとは別にshapeのデータを保持しておくように実装します。

2.2.1 初期化まで

 初期化処理まで書きました。ちなみにReinというのは実装するライブラリの名前です。

Tensor.cs
using System.Linq;
using System.Collections.Generic;
using System;
using System.Runtime.Serialization;
using R = System.Double;
using Rein.Utils.Exceptions;
using Rein.Functions;

namespace Rein
{
    [Serializable]
    public partial class Tensor
    {
        // Dataは変数データ、Gradは勾配データを格納する
        public R[] Data, Grad;
        // Shapeはデータの形を保存している
        public List<int> Shape;
        public int Size;
        // UseCountは計算グラフで使用された回数を保存することで、勾配の計算漏れを防ぐ
        public int UseCount = 0;
        // Backward時に呼び出す。IFunctionはRein.Functionsのinterface
        public IFunction BackFunction;

        // データの形で初期化
        public Tensor(int[] shape)
        {
            System.Random random = new System.Random();
            this.Shape = shape.ToList();
            this.Size = shape.Aggregate((now, next) => now * next);
            // 乱数で初期化
            this.Data = Enumerable.Range(0, this.Size).Select(_ => (R)random.NextDouble()).ToArray();
            this.Grad = new R[this.Size];
        }

        // データを直接入力して初期化
        public Tensor(R[] data)
        {
            this.Shape = new List<int>(1){ data.Length };
            this.Size = data.Length;
            this.Data = data;
            this.Grad = new R[this.Size];
        }

        // データとshapeで初期化
        public Tensor(R[] data, int[] shape)
        {
            this.Shape = shape.ToList();
            this.Size = shape.Aggregate((now, next) => now * next);
            // データ自体のサイズとshapeから得られるサイズが異なる時にエラーを投げる。
            if(data.Length != this.Size) throw new InvalidSizeException();
            this.Data = data;
            this.Grad = new R[this.Size];
        }

        // Shapeがリスト型で与えられた時
        public Tensor(R[] data, List<int> shape){
            this.Shape = shape;
            this.Size = shape.Aggregate((now, next) => now * next);
            if(data.Length != this.Size) throw new InvalidSizeException();
            this.Data = data;
            this.Grad = new R[this.Size];
        }
    }
}


IFunctionとInvalidSizeExceptionはまだ実装していないので後で実装しています。

2.2.2 Backward処理

 Backwardについては簡単です。Tensorを入力としている全ての関数についてBackwardを行ったら(つまりUseCountが0になったら)自身のBackFunctionのBackwardを呼び出します。

Tensor.cs
public partial class Tensor{
    //...
    public void Backward()
    {
        // BackFunctionが存在しない時は終了
        if(this.BackFunction == null)return;
        this.UseCount--;
        // 他の関数に対しても出力している場合にはまだ勾配を計算しない
        if(this.UseCount != 0)return;
        this.BackFunction.Backward();
    }
}

3. IFunctionの実装

 これもまたライブラリの根幹をなすインターフェースで計算グラフはTensorとFunctionで構成できるのでこれもまた最重要と言える部品です。とは言えインターフェースなので本当に最小限だけ実装します。

IFunction.cs
namespace Rein.Functions
{
    public interface IFunction
    {
        public Tensor[] Forward(Tensor[] inputs);
        public void Backward();
    }
}

 これだけです。以下のような図を考えると理解しやすいと思います。
auto_grad.jpg
 FunctionはForward処理の際に入力$(X,Y)$と出力$(Z,W)$を保存しておきBackward処理が出力の個数分(上図では2回)呼び出されたらFunctionの勾配を計算して入力$(X,Y)$に伝播させます。流れとしては以下の通りです。

$F.Forward([X,Y])$ ($F$に出力と入力が保存)
最終的な出力の$Backward$が呼び出される。
...(再帰的に$Backward$が行われれていく)
$W.Backward()$
 →$F.Backward()$ (何もしない)
$Z.Backward()$
 →$F.Backward()$ (伝播処理を開始)
  →$X.Backward()$
  →$Y.Backward()$

 あと、関数に対する入出力はTensorの配列を想定しています。そうしないと分岐するネットワーくを構成することが難しくなるので。

4. コード

ここまでのコードはhttps://github.com/aokyut/Rein/tree/v0.0.1に置いてあります。このURLの最後に.gitをつけたものをpackage managerからインポートすることができます。

5. まとめ

 というわけで今回は関数のインターフェースの実装まで行いました。今までPythonなどで実装をする時にはライブラリが充実していて細かい処理の中身までは考えていなかったのでこういう基礎の部分をゼロから作っていくのは楽しいです。
 ちなみに私は「ゼロから作るDeep Learning」は読んだことは無いのでもしかしたら内容が被っているかもしれませんがその時は教えていただければ修正します。
 次はTensorにおける四則演算の実装をしたいと思います。

6
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?