0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

UnityでPytorchライクの機械学習ライブラリを作る。 4日目:行列積の勾配処理

Last updated at Posted at 2021-02-20

今回すること

 今回は行列積の勾配処理を実装していきますが、その前に行列の場合はどのようにして勾配を計算するか考えていきます。
 まず$N\times L$の行列$A$と$L\times M$の行列$B$の積を$C=AB$とおきます。
 この時$C$の$i$行$j$列目の要素$C_{ij}$は以下のように計算されます。

C_{ij} = \sum_{k=1}^{L}A_{ik}B_{kj}

 適当な$s$を用いて$A_{is}$で両辺を微分すると

\frac{\partial C_{ij}}{\partial A_{is}}=B_{sj}

 よって最終的な関数の出力を$E$とすると、

\frac{\partial E}{\partial A_{is}}=\sum_{j=1}^{M}\frac{\partial E}{\partial C_{ij}}B_{sj}

 これは実装上では

A.Grad[i, s]=\sum_{j=1}^{M}C.Grad[i, j]\cdot B[s, j]

 ここで、転置$B[s, j]=B^T[j, s]$を代入すると

A.Grad[i, s]=\sum_{j=1}^{M}C.Grad[i, j]\cdot B^T[j, s]

 よって$A.Grad=C.Grad\cdot B^T$

 同様に$B$について考えると

B.Grad[s, j]=\sum_{i=1}^{N}C.Grad[i, j]\cdot A^T[s, i]

 となって$B.Grad=A^T\cdot C.Grad$
 これらを計算すれば良い。

実装

 上のAとBで計算の仕方が異なるので別のループで計算する。

Dot.cs
namespace Rein.Functions{
    public class Dot: BinaryFunction{
        public Dot():base("Dot"){

        }

// ...

        protected override void BinaryBackward()
        {
            int N = this.Left.Size / this.Left.Shape.Last();
            int M= this.Right.Shape[1];
            int L = this.Right.Shape[0];

            // A(左側)
            for (int i=0; i < N; i++){
                for(int j=0; j < L; j++){
                    R sum_left = 0;
                    for(int k=0; k < M; k++){
                        sum_left += this.Out.Grad[i * M + k] * this.Right.Data[j * M + k];
                    }
                    this.Left.Grad[i * L + j] += sum_left;
                }
            }

            // B(右側)
            for (int k=0; k < N; k++){
                for (int i=0; i < L; i++){
                    for (int j=0; j < M; j++){
                        this.Right.Grad[i * M + j] += this.Out.Grad[k * M + j] * this.Left.Data[k * L + i];
                    }
                }
            }
        }
    }
}

 A(左側)の処理ですが運のいいことにそのまま計算すれば配列内の移動幅が少なく済むのでループ交換は必要ありません。Bの処理はkを最も上のループにすれば移動幅が少なく済みます。
 Aで一度sumに代入している理由ですが、一番下のkのループで毎回配列にアクセスすると速度が落ちるので一つのに入れておき最後に代入することで速度が上がります。

 あとこれにループアンローリングと代入を減らした物が最終的な実装となります。

Dot.cs
// ...

        protected override void BinaryBackward()
        {
            int N = this.Left.Size / this.Left.Shape.Last();
            int M= this.Right.Shape[1];
            int L = this.Right.Shape[0];

            int i, j, k, k1, k2, k3;

            for (i=0; i < N; i++){

                for(j=0; j < L - 3; j++){
                    R sum = 0;
                    for(k=0; k < M; k+=4){
                        sum += this.Out.Grad[i * M + k] * this.Right.Data[j * M + k]
                            + this.Out.Grad[i * M + k + 1] * this.Right.Data[j * M + k + 1]
                            + this.Out.Grad[i * M + k + 2] * this.Right.Data[j * M + k + 2]
                            + this.Out.Grad[i * M + k + 3] * this.Right.Data[j * M + k + 3];
                    }

                    for(; k < M; k++){
                        sum += this.Out.Grad[i * M + k] * this.Right.Data[j * M + k];
                    }
                    this.Left.Grad[i * L + j] += sum;
                }
            }

            for (k=0; k < N - 3; k+=4){
                k1 = k + 1;
                k2 = k + 2;
                k3 = k + 3;
                for (i=0; i < L; i++){
                    for (j=0; j < M; j++){
                        this.Right.Grad[i * M + j] += this.Out.Grad[k * M + j] * this.Left.Data[k * L + i]
                        + this.Out.Grad[k1 * M + j] * this.Left.Data[k1 * L + i]
                        + this.Out.Grad[k2 * M + j] * this.Left.Data[k2 * L + i]
                        + this.Out.Grad[k3 * M + j] * this.Left.Data[k3 * L + i];
                    }
                }
            }

            for (; k < N; k++){
                for (i=0; i < L; i++){
                    for (j=0; j < M; j++){
                        this.Right.Grad[i * M + j] += this.Out.Grad[k * M + j] * this.Left.Data[k * L + i];
                    }
                }
            }
        }

終わりに

 ということで今回は行列積の勾配処理を実装しました。思いの外時間がかかってしまいましたが、重要な部分なのでまあいいでしょう。次こそは単項演算の実装をします。というか実装自体は終わっていてそれを文章にまとめるだけなのですぐに終わると思います。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?