3
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

chainerのconnectionをいじって新しい層を作る(2)

Last updated at Posted at 2016-10-12

環境

GPU GTX1070
ubuntu 14.04
chainer 1.14.0
など

はじめに

chainerで最新のモデルを実装する際には、links/connectionやfunctions/connectionをいじる必要がある。そこで最も単純なlinear.pyをいじって、新しい層を作ってみる。

前回はchainer/functions/connection/linear.pyのforward関数をいじって順伝播を改良した。
http://qiita.com/masataka46/items/1a5d6cbd49279aeaf734

今回はbackward関数をいじって誤差逆伝播を改良する。

誤差逆伝播の計算

誤差逆伝播の演算は以下の図のようになる。
backward01.png
gxを求めるのに、Wのin_size / n側をn倍すれば計算が楽になる。またgWを求める演算を図式化すると以下のようになる。
backward02.png
今度はgxとxとの行列積で出力されたものをin_size軸側に圧縮・足しあわせる必要がある。この際、cupy.sum()を使う。

linear.pyの変更

chainer/functions/connection/linear.py内のbackward関数を以下のように修正する。

    def backward(self, inputs, grad_outputs):

        x = _as_mat(inputs[0])
        W = inputs[1]
        gy = grad_outputs[0]
        
        W_tile = cupy.tile(W, (1, common_num)).astype(W.dtype, copy=False)

        gx = gy.dot(W_tile).astype(gy.dtype, copy=False)
        gW_wide = gy.T.dot(x).astype(W.dtype, copy=False)
        gW_cube = gW_wide.reshape(len(gW_wide), common_num, -1).astype(W.dtype, copy=False)
        gW = gW_cube.sum(axis=1).astype(W.dtype, copy=False)

        if len(inputs) == 3:
            gb = gy.sum(0)
            return gx, gW, gb
        else:
            return gx, gW

行列積の結果(gW_wide)に対して、まず次元を増やす(gW_cube)。次にその次元に関して足しあわせて(sum関数)いる。

修正モデルの学習結果

修正モデルを走らせて、accuracyと処理速度を元のモデルと比較する。

python train_mnist3.py -g=0 -e=50
GPU: 0
# unit: 1000
# Minibatch-size: 100
# epoch: 50

epoch       main/loss   validation/main/loss  main/accuracy  validation/main/accuracy
1           0.204286    0.10436               0.936834       0.9683                    
2           0.0795502   0.0783082             0.975266       0.974
・・・
49          0.00331749  0.124306              0.999166       0.9843                    
50          0.00580243  0.136195              0.998583       0.9836    

50回の学習で1分30秒、accuracyは0.968から0.984に上昇した。元のモデルだと学習時間1分54秒でaccuracyは0.971から0.985である。よって、予想通り学習時間が減って、accuracyはわずかに減少している。

3
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?