2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

UnityでPytorchライクの機械学習ライブラリを作る。2日目:Tensorクラスの実装その2(四則演算)

Last updated at Posted at 2021-02-17

1. はじめに

 前回はTensorクラスの一部とTensorを加工するIFunctionインターフェースの実装を行いました。今回は、四則演算を行う関数を定義し、それを用いてTensorの四則演算を実装していきたいと思います。

2. 関数クラスについて

 関数クラスはこのライブラリにおいて、Tensorを加工する重要なクラスとなります。関数クラスは全てIFunctionから派生したBaseFunctionを継承するようにしてその下に入力と出力の関係から分類した抽象クラスを実装して最後に具体的な関数のクラスを実装します。下の図のような構成です。

Functions.jpg

UnaryFunctionは一つの入力に対して一つの出力が得られるような関数に対して定義しようと考えています。
SetFunctionは集合に対する関数(minやmax, sum, meanなど)の基底クラスとなる予定です。これらの他にもTensorのShapeを操作したりTensorを二つに分けたりする関数なども考えています。

3. 関数クラスの実装1(BaseFunction)

 ここではIFunctionでForwardやBackwardを呼び出された際の具体的な動作を実装していきます。これによって子クラスでは勾配の伝播のための処理を省くことができます。

BaseFunction.cs

using System;

namespace Rein.Functions
{
    public abstract class BaseFunction: IFunction
    {
        protected Tensor[] Inputs, Outputs;
        protected int UseCount = 0;
        
        protected Func<Tensor[], Tensor[]> FunctionForward;
        protected Action FunctionBackward;
        
        protected string Name;

        public BaseFunction(string name){
            this.Name = name;
        }

        public Tensor[] Forward(params Tensor[] inputs)
        {
            foreach (Tensor input in inputs){
                input.UseCount++;
            }
            this.Inputs = inputs;
            this.Outputs = this.FunctionForward(inputs);
            foreach (Tensor output in this.Outputs){
                output.BackFunction ??= this;
            }
            this.UseCount = this.Outputs.Length;

            return this.Outputs;
        }

        public void Backward()
        {
            this.UseCount--;
            if (this.UseCount != 0)return;
            this.FunctionBackward();
            foreach(Tensor input in this.Inputs){
                input.Backward();
            }
        }
    }
}

4. 関数クラスの実装2(BinaryFunction)

 ここでは、さらに入力を二つ出力が一つであるような関数の処理を求めています。BaseFunctionで定義したFunctionForwardとFunctionBackwardに対して、二つの入力と一つの出力を行うBinaryForwardとBinaryBackwardを追加し、継承したクラスで実装しやすいようLeft(演算子の左), Right(演算子の右), Outの三つを定義しています。

BinaryFunction.cs
namespace Rein.Functions
{
    public abstract class BinaryFunction: BaseFunction
    {
        protected abstract Tensor BinaryForward(Tensor tensor1, Tensor tensor2);
        protected abstract void BinaryBackward();

        protected Tensor Left{
            get{
                return this.Inputs[0];
            }
            set{
                this.Inputs[0] = value;
            }
        }
        protected Tensor Right{
            get{
                return this.Inputs[1];
            }
            set{
                this.Inputs[1] = value;
            }
        }
        protected Tensor Out{
            get{
                return this.Outputs[0];
            }
            set{
                this.Outputs[0] = value;
            }
        }

        public BinaryFunction(string name): base(name)
        {
            this.FunctionForward = (tensor) => new Tensor[1]{ this.BinaryForward(tensor[0], tensor[1]) };
            this.FunctionBackward = this.BinaryBackward;
        }
    }
}

5. 四則演算クラスの実装

 正直四則演算のクラス同士にそこまで差異はありません。一つ実装できたら他も簡単に実装できると思います。二項演算においては二つの入力を持つ関数を考えます。

O = f(L, R)

最終的な出力を$E$とすると$E.Backward()$を行った際にはそれぞれのTensorに対する$E$の勾配が計算されていくので

\frac{\partial E}{\partial L}=\frac{\partial E}{\partial O}\frac{\partial O}{\partial L}

 しかし、実際は$L$を入力とするのは$O$だけとは限らないので、$O$に対して添字$i=(1, 2, 3, ..., n)$がつき、

\frac{\partial E}{\partial L}=\sum_{i=1}^{n}\frac{\partial E}{\partial O_i}\frac{\partial O_i}{\partial L}

 よって二項演算において実際に行う計算は

L.Backward\leftarrow L.Backward + O.Backward \times \frac{\partial O}{\partial L}

5.1 Add(和)

 足し算の場合

O = f(L, R) = L + R

 実際にはTensorの配列に対して行うので、

O_i = L_i + R_i

 さらに微分は

\frac{\partial O_i}{\partial L_i}=1\\
\frac{\partial O_i}{\partial R_i}=1

 したがって実装は

Add.cs
using System.Linq;
using R = System.Double;

namespace Rein.Functions.Arithmetic{
    public class Add: BinaryFunction{

        public Add():base("Add"){

        }

        protected override Tensor BinaryForward(Tensor left, Tensor right){
            R[] data = new R[left.Size];
            
            for(int i=0; i < left.Size; i++){
                data[i] = left.Data[i] + right.Data[i];
            }

            return new Tensor(
                data,
                left.Shape
            );
        }

        protected override void BinaryBackward(){
            for(int i = 0; i < this.Left.Size; i++){
                this.Left.Grad[i] += this.Out.Grad[i];
                this.Right.Grad[i] += this.Out.Grad[i];
            }
        }
    }
}

5.2 Sub(差)

 引き算の場合

O_i = L_i - R_i

 微分は

\frac{\partial O_i}{\partial L_i}=1\\
\frac{\partial O_i}{\partial R_i}=-1

 実装は

Sub.cs
using System.Linq;
using R = System.Double;

namespace Rein.Functions.Arithmetic{
    public class Sub: BinaryFunction{

        public Sub(): base("Sub"){

        }
        protected override Tensor BinaryForward(Tensor left, Tensor right){
            R[] data = new R[left.Size];
            
            for(int i=0; i < left.Size; i++){
                data[i] = left.Data[i] - right.Data[i];
            }

            return new Tensor(
                data,
                left.Shape
            );
        }

        protected override void BinaryBackward(){
            for(int i = 0; i < this.Left.Size; i++){
                this.Left.Grad[i] += this.Out.Grad[i];
                this.Right.Grad[i] -= this.Out.Grad[i];
            }
        }
    }
}using System.Linq;
using R = System.Double;

namespace Rein.Functions.Arithmetic{
    public class Sub: BinaryFunction{

        public Sub(): base("Sub"){

        }
        protected override Tensor BinaryForward(Tensor left, Tensor right){
            R[] data = new R[left.Size];
            
            for(int i=0; i < left.Size; i++){
                data[i] = left.Data[i] - right.Data[i];
            }

            return new Tensor(
                data,
                left.Shape
            );
        }

        protected override void BinaryBackward(){
            for(int i = 0; i < this.Left.Size; i++){
                this.Left.Grad[i] += this.Out.Grad[i];
                this.Right.Grad[i] -= this.Out.Grad[i];
            }
        }
    }
}

5.3 Mul(積)

 掛け算の場合

O_i = L_i * R_i

 微分は

\frac{\partial O_i}{\partial L_i}=R_i\\
\frac{\partial O_i}{\partial R_i}=L_i

 よって実装は

Mul.cs
using System.Linq;
using R = System.Double;

namespace Rein.Functions.Arithmetic{
    public class Mul: BinaryFunction{

        public Mul(): base("Mul"){

        }
        protected override Tensor BinaryForward(Tensor left, Tensor right){
            R[] data = new R[left.Size];
            
            for(int i=0; i < left.Size; i++){
                data[i] = left.Data[i] * right.Data[i];
            }

            return new Tensor(
                data,
                left.Shape
            );
        }

        protected override void BinaryBackward(){
            for(int i = 0; i < this.Left.Size; i++){
                this.Left.Grad[i] += this.Out.Grad[i] * this.Right.Data[i];
                this.Right.Grad[i] += this.Out.Grad[i] * this.Left.Data[i];
            }
        }
    }
}

5.4 Div(商)

 割り算の場合

O_i = L_i / R_i

 微分は

\frac{\partial O_i}{\partial L_i}=\frac{1}{R_i}\\
\frac{\partial O_i}{\partial R_i}=-\frac{L_i}{R_i^2}=-\frac{O_i}{R_i}

 $O_i$については$Forward$で計算しているので上のようにすることで計算回数を抑えられる。
 これを用いて実装を行うと

Div.cs
using System.Linq;
using R = System.Double;

namespace Rein.Functions.Arithmetic{
    public class Div: BinaryFunction{

        public Div(): base("Div"){

        }
        protected override Tensor BinaryForward(Tensor left, Tensor right){
            R[] data = new R[left.Size];
            
            for(int i=0; i < left.Size; i++){
                data[i] = left.Data[i] / right.Data[i];
            }

            return new Tensor(
                data,
                left.Shape
            );
        }

        protected override void BinaryBackward(){
            for(int i = 0; i < this.Left.Size; i++){
                this.Left.Grad[i] += this.Out.Grad[i] / this.Right.Data[i];
                this.Right.Grad[i] -= this.Out.Grad[i] * this.Out.Data[i] / this.Right.Data[i];
            }
        }
    }
}

6. Tensorクラスの実装(四則演算)

 Tensorクラスは割と処理の内容が多くなってくると思うので前回使用したファイルとは別のファイルに四則演算を実装します。(そのために前回partialでTensorクラスを実装しました)
 演算子の実装では、それぞれの関数のコンストラクタを呼び出すことで計算グラフを構築しながら演算を行うことができるようになりました。

Tensor.Operator
using Rein.Functions.Arithmetic;
using Rein.Utils.Exceptions;

namespace Rein
{
    public partial class Tensor
    {
        // 演算子のオーバーロード
        public static Tensor operator +(Tensor tensor1, Tensor tensor2)
        {
            return new Add().Forward(tensor1, tensor2);
        }

        public static Tensor operator -(Tensor tensor1, Tensor tensor2)
        {
            return new Sub().Forward(tensor1, tensor2);
        }

        public static Tensor operator -(Tensor tensor)
        {
            return null;
        }

        public static Tensor operator *(Tensor tensor1, Tensor tensor2)
        {
            return new Mul().Forward(tensor1, tensor2);
        }

        public static Tensor operator /(Tensor tensor1, Tensor tensor2)
        {
            return new Div().Forward(tensor1, tensor2);
        }

        public static implicit operator Tensor(Tensor[] tensor1)
        {
            if(!(tensor1.Length == 1))throw new InvalidLengthException();
            return tensor1[0];
        }
    }
}

 最後のTensor[]$\rightarrow$Tensorへの変換は実装するか悩みました。しかしこれが無いと出力一つの関数(minやmaxなど)を使用するたびにインデックス0を指定しないといけなくなるので(IFunctionではTensor[]で数値のやりとりを行うため)利便性のためにこれを追加することにしました。

8. コード

 ここまでのコードはhttps://github.com/aokyut/Rein/tree/v0.0.2で公開して居ます。現時点では使い物になりませんが、続きを実装してみたい方やこれまでの実装を確認したい方は見てみてください。

7. 終わりに

 ということで今回は関数の基底クラスの定義とTensorクラスの四則演算の定義を行いました。機械学習の実用に使えるようなものではありませんがTensor同士で計算できるようにはなりました。次は多次元配列の演算として多用されるドット演算や単項演算あたりを定義したいと思います。

2
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?