3
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

UnityでPytorchライクの機械学習ライブラリを作る。8日目:ロス関数とTensor操作

Posted at

はじめに

 今回はロス関数とTensorに対する操作(detach, reshapeなど)を実装したいと思います。

その前に

 実装でロス関数をBackward処理を行う際に

Tensor loss = LossFunction(y, target)
loss.Backward();

 と呼び出したいのですが、そのままだとBackward処理が連鎖していかない(起点となる処理が必要)ので、Tensorクラスで現在Backward関数としている関数をBackwardChainに変更し新たにBackwardを追加します。

Tensor.cs
namespace Rein
{
    [Serializable]
    public partial class Tensor
    {
        public void BackwardChain()
        {
            // BackFunctionが存在しない時は終了
            if(this.BackFunction == null)return;
            this.UseCount--;
            // 他の関数に対しても出力している場合にはまだ勾配を計算しない
            if(this.UseCount != 0)return;
            this.BackFunction.Backward();
        }

        public void Backward(){
       // 一つの変数しか持たないことを確認する
            if (this.Size > 1)throw new InvalidSizeException($"expect size : 1, but actual : {this.Size}");
            this.Grad[0] = 1.0;
            this.BackFunction.Backward();
        }
    }
}

Loss関数の実装

 それではいくつか主となるLoss関数を実装していきます。基本的にはLossの関数をLambdaで計算した後SumやMeanを計算することになります

MSELoss(二乗誤差)

 これは以下のような関数です

Loss_{MSE}=\frac{1}{n}\sum_{i=1}^{n}(y_i-t_i)^2

MSELossの実装

F.cs内部に直接LambdaFunctionとして実装していきます。

F.cs
namespace Rein{
    public static class F{
        public static Tensor MSELoss(Tensor In){
            return new Lambda(
                "MSELoss",
                (x) => x * x,
                (x) => 2 * x
            ).Forward(In)[0].Mean();
        }
    }
}

HuberLossの実装

HuberLossは以下のような計算を行います。

f_{huber}(x) = \left\{
\begin{array}{ll}
\frac{1}{2}x^2 & (-\delta \leq x \leq \delta) \\
\delta|x|-\frac{1}{2}\delta^2 & (x \lt -\delta \, or\, x \gt \delta)
\end{array}
\right.\\
L_{huber}=\frac{1}{n}\sum_{i=1}^{n}f_{huber}(y_i-t_i)

 実装

 こちらも同様にF.csに加えていきます。

F.cs
namespace Rein{
    public static class F{
        public static Tensor HuberLoss(Tensor left, Tensor right, R delta = 1.0){
            R deltaSquare = delta * delta / 2;
            return new Lambda(
                "HuberLossFunction",
                new Func<R, R>((x) => 
                x < -delta ? -delta * x - deltaSquare : 
                (x > delta ? delta * x - deltaSquare : x * x / 2)),
                new Func<R, R>((x) => 
                x < -delta ? -delta :
                (x > delta ? delta : x))
                ).Forward(left - right);
        }public static Tensor HuberLoss(Tensor left, Tensor right, R delta = 1.0){
            R deltaSquare = delta * delta / 2;
            return new Lambda(
                "HuberLossFunction",
                new Func<R, R>((x) => 
                x < -delta ? -delta * x - deltaSquare : 
                (x > delta ? delta * x - deltaSquare : x * x / 2)),
                new Func<R, R>((x) => 
                x < -delta ? -delta :
                (x > delta ? delta : x))
                ).Forward(left - right);
        }
    }
}

Tensorの操作

 次はいくつかTensorの構造に作用する関数を実装していきたいと思います。Tensorの操作を行うメソッドでは基本的にShapeに作用するためDataの中身を変えないため、入力したTensorと同じインスタンスが出力されることとなります。

Detach

 これはTensorの依存関係を切り離し、勾配の伝播を止める操作です。要は学習はさせないがネットワークの出力だけ欲しいという時に使う関数です。これをTensorの関数として実装したいのですが、一つ問題があります。
例えば以下のような形式で使用するとします。

Tensor y = network(x).detach();
Tensor z = network(t);
Tensor loss = (y - z) * (y - z);
loss.Backward();

 ここでTensor yは独立したBackFuncを持たないTensorとなるのですが、network内部ではxが入力された時に計算グラフが作られ保存されているので、これらの関係を解消するためには一々yからグラフを遡る必要が出てきます。
 そのため、残念ながらTensorの操作としてのDetach操作は断念せざるを得ません。
 そこで、代わりにBaseFunctionに「勾配情報を保存しないForward」を定義します。これをPredictとします。

実装(IFunction.csの追記)

 まずIFunctionに対してPredictを追加します。

IFunction.cs
namespace Rein.Functions
{
    public interface IFunction
    {
        public Tensor[] Forward(params Tensor[] inputs);

        public Tensor[] Predict(params Tensor[] inputs);
        public void Backward();

        public Tensor[] Parameters {get; }
    }
}

実装(BaseFuncttion.csの追記)

 IFunctionに追加した関数の詳細をBaseFunctionで定義します。

BaseFunction.cs
namespace Rein.Functions
{
    public abstract class BaseFunction: IFunction
    {
        // ...
        public virtual Tensor[] Predict(params Tensor[] inputs){
            return this.FunctionForward(inputs);
        }
        // ...
    }
}

 これを使用することで、学習時に勾配を計算させないようにすることができます。PytorchのようにDetachをTensorの操作として呼び出したいなら、計算グラフの実装方法を変える必要があるようです。

Squeeze・Unsqueeze

 SqueezeはTensorのある軸方向のサイズが1の時にその軸を消し次元を減らす操作で、
 Unsqueezeは逆に次元を増やす操作です。これらも同様に関数クラスとして実装しTensorから呼び出せるようにしておきます。

Squeezeの実装

Squeeze.cs
namespace Rein.Functions.Process{
    public class Squeeze: UnaryFunction{
        private List<int> InShape;
        private int Dim;
        public Squeeze(int dim): base($"Squeeze-{dim}"){
            this.Dim = dim;
        }

        protected override Tensor UnaryForward(Tensor tensor)
        {
            this.InShape = new List<int>(tensor.Shape);
            if(tensor.Shape[this.Dim] == 1)tensor.Shape.RemoveAt(this.Dim);
            return tensor;
        }

        protected override void UnaryBackward()
        {
            this.In.Shape = this.InShape;
        }
    }
}

Unsqueezeの実装

Unsqueeze
namespace Rein.Functions.Process{
    public class Unsqueeze: UnaryFunction{
        private List<int> InShape;
        private int Dim;
        public Unsqueeze(int dim): base($"Unsqueeze-{dim}"){
            this.Dim = dim;
        }

        protected override Tensor UnaryForward(Tensor tensor)
        {
            this.InShape = new List<int>(tensor.Shape);
            tensor.Shape.Insert(this.Dim, 1);
            return tensor;
        }

        protected override void UnaryBackward()
        {
            this.In.Shape = this.InShape;
        }
    }
}

Reshape

 ReshapeでもSqueezeと同様にTensorのデータは変えずにShapeのみを入れ替えることになります。

実装

Reshape.cs
namespace Rein.Functions.Process{
    public class Reshape: UnaryFunction{
        private List<int> OutShape;
        private List<int> InShape;
        public Reshape(List<int> shape): base($"Reshape-({string.Join(",", shape)})"){
            this.OutShape = shape;
        }

        protected override Tensor UnaryForward(Tensor tensor)
        {
            // サイズ確認
            if (this.OutShape.Aggregate((now, next) => now * next) != tensor.Size)
                throw new InvalidShapeException($"Expected Output Shape : ({string.Join(",", this.OutShape)})  ,Input Shape :({string.Join(",", tensor.Shape)})");
            this.InShape = tensor.Shape;
            tensor.Shape = this.OutShape;

            return tensor;
        }

        protected override void UnaryBackward()
        {
            this.In.Shape = this.InShape;
        }
    }
}

Tensorクラスへの追加

 ここまで実装したクラスのForwardをTensorから実行できるようにしておきます。

Tensor.Processing.cs
namespace Rein
{
    public partial class Tensor
    {
        public Tensor Detach(){
            return new Detach().Forward(this);
        }

        public Tensor Squeeze(int dim){
            return new Squeeze(dim).Forward(this);
        }
        
        public Tensor  Unsqueeze(int dim = 0){
            return new Unsqueeze(dim).Forward(this);
        }

        public Tensor Reshape(List<int> shape){
            return new Reshape(shape).Forward(this);
        }
    }
}

 これでTensor側でいつでも操作できるようになりました。

終わりに

 今回は、ロス関数とTensorの操作関数を定義しました。ロス関数は他にもクロスエントロピーとかがよく使うと思いますが、現時点では使わなさそうなので必要になったら実装しようと思います。
 次はOptimizerの実装を行います。

3
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?