LoginSignup
4
4

More than 3 years have passed since last update.

UnityでPytorchライクの機械学習ライブラリを作る。7日目:Linearクラスの実装

Last updated at Posted at 2021-02-24

はじめに

 今回はいよいよニューラルネットワークの基本とも言える、Linear層を実装したいと思います。他の実装と比べて特殊な点は、インスタンスごとにParamデータを格納しておき、学習時にそれらの更新ができるようになっているという点です。

Linearが行っている操作

 全結合層(LinearまたはAffine, Fully Connected Layerなど)で行われているのは以下のような操作です。これらは基本的に今まで実装した関数で実装できると思いましたが、バイアスの足し算の際に次元を補完するためにRepeat関数を追加します。

\boldsymbol{O} = {\bf{}A\cdot W} + \bf{B}

 これは実装上では以下のように表されます。$*$は0またはそれより大きい数の次元が入ります。

O[*, H_{out}] = A[*, H_{in}]\cdot W[H_{in}, H_{out}]+B[*, H_{out}]

 また、学習の際にパラメータを操作するために、ParamersプロパティをIFunctionに設定します。

次元補完用のRepeatクラスを実装

 Repeatクラスの実装は集合演算のクラスと似ていて、対象となる次元と繰り返し回数を否定することで指定した次元方向に複製したTensorを出力します。

実装

Repeat.cs
namespace Rein.Functions.Process{
    public class Repeat: UnaryFunction{
        int Dim, RepeatNum, Step;
        public Repeat(int dim, int rep): base($"Expand-d:{dim}-r:{rep}"){
            this.Dim = dim;
            this.RepeatNum = rep;
        }

        protected override Tensor UnaryForward(Tensor tensor)
        {
            if (tensor.Shape.Count < this.Dim - 1) throw new InvalidShapeException();
            R[] data = new R[tensor.Size * this.RepeatNum];
            // 対象となるDim以下の要素数を取得
            this.Step = tensor.Shape[this.Dim] * tensor.Size / tensor.Shape.GetRange(0, this.Dim + 1).Aggregate((now, next) => now * next);

            for (int i = 0; i < tensor.Size; i += this.Step){
                for (int j = 0; j < this.RepeatNum; j++){
                    Array.Copy(tensor.Data, i, data, i * this.RepeatNum + j * this.Step, this.Step);
                }
            }

            List<int> shape = new List<int>(tensor.Shape);
            shape[this.Dim] *= this.RepeatNum;

            return new Tensor(data, shape);
        }

        protected override void UnaryBackward()
        {
            for (int i = 0; i < this.In.Size; i += this.Step){
                for (int j = 0; j < this.RepeatNum; j++){
                    for (int k = 0; k < this.Step; k++){
                        this.In.Grad[i + k] += this.Out.Grad[i * this.RepeatNum + j * this.Step + k];
                    }
                }
            }
        }
    }
}

学習パラメータの追加

 学習対象となるパラメータを持った関数からパラメータを取り出すためにIFunctionインターフェースにParametersプロパティとBaseFunctionにParamsメンバを追加しておきます。

IFunctionの変更

IFunction.cs
namespace Rein.Functions{
    public interface IFunction{
        // ...
        public Tensor[] Parameters {get;}
    }
}

BaseFunctionの変更

BaseFunction.cs
namespace Rein.Functions
{
    public abstract class BaseFunction: IFunction
    {
        // ...
        public Tensor[] Params = new Tensor[]{};
        // ...
        public Tensor[] Parameters{
            get{
                return this.Params;
            }
        }
    }
}

 これでIFunction.Parametersにアクセスできるようになりました。

Linearクラスの実装

 これでLinearクラスに必要なものが揃いましたので実装していきます。ここで問題となるのが、Linearがどのクラスを継承するのかです。LinearクラスのForwardでは、DotとBiasの加算が行われますが、これはすでに実装しており、Linearクラスはこれら二つの演算と二つのパラメータを格納するクラスとなっています。そのためBaseFunctionのForward処理をオーバーライトする必要があります。

実装

Linear.cs
namespace Rein.Functions.Layer{
    public class Linear: UnaryFunction{
        public Tensor Weight, Bias;

        public Linear(int inputSize, int outputSize, bool bias = true): base("Linear"){
            this.Weight = new Tensor(new int[]{inputSize, outputSize});

            if (bias){
                this.Bias = new Tensor(new int[]{outputSize});
            }

            this.Params = new Tensor[]{Weight, Bias};
        }

        public override Tensor[] Forward(Tensor[] tensors)
        {
            Tensor outTensor = F.Dot(tensor[0], this.Weight);

            if (this.Bias != null){
                if (outTensor.Shape.Count == 0){
                    outTensor += this.Bias;
                }else{
                    outTensor += this.Bias.Repeat(0, outTensor.Size / outTensor.Shape.Last());
                }
            }

            return outTensor;
        }

        public override void Backward()
        {
            // 何もしない
        }
    }
}

終わりに

 今回はLinear関数の実装をしました。やっとパラメータを持った関数を実装することができました。ここで改めて小目標として設定した「DQNを実装する」ですが、現時点でこれに必要なのは

  • optimizer(簡単なものを二つほど実装する)
  • detachメソッド(勾配を伝播させないようにするためのメソッド)
  • ロス関数(二乗誤差などは既に実装できますが)
  • ネットワークのパラメータを保存、読み込み

の四つぐらいですかね。上三つに関してはこれまでの延長線上で出来そうな気がするのですが、パラメータの保存、読み込み辺りはどういう形式で保存すれば良いか分からないです。Pytorchとかではどう実装してるんでしょうか。

 そこら辺は追々考えることにして次回はdetachメソッド、ロス関数あたりを実装していきます。

4
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
4