6
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

chainerのconnectionをいじって新しい層を作る(1)

Last updated at Posted at 2016-10-12

##環境
GPU GTX1070
ubuntu 14.04
chainer 1.14.0
など
##はじめに
chainerで最新のモデルを実装する際には、links/connectionやfunctions/connectionをいじる必要がある。

そこで最も単純なlinear.pyをいじって、新しい層を作ってみる。前回はlinear.pyの中身を確認した。
http://qiita.com/masataka46/items/d66997ac94ec7aa3bcb4

今回はchainer/functions/connection/linear.pyのforward関数をいじって順伝播を改良する。

##改良モデルの概要
元々の全結合3層を以下の図のように改良する。
img_161013_4.png
2層目だけを改良する。この2層目は具体的に以下のように入力側に関して重みを共有する。
forward01.png
この演算処理は以下の図のようになる。
forward02.png
Wは入力側n個で重みを共有するので、W(out_size, in_size / n)となる。この重みを1度の行列積で計算できるよう、in_size / n側をn倍し、in_sizeとする。

この重みと入力側からのデータxとの行列積を求めると、y(batch_size, out_size)が出力される。これにより重みのパラメーター数が減るので、演算は速くなるだろう。そして、性能が若干低下するだろう。また今回、計算を簡略化するためbiasは使わないでおく。
##tain_mnist.pyを修正する
train_mnist.pyも若干変わってくるので修正する。

common_num = 10
out_units = 900
    #chnged model
    def __init__(self, n_in, n_units, n_out):
        super(MLP, self).__init__(
            l1=L.Linear(n_in, n_units),  # first layer
            l2=linear_link.Linear(n_units / common_num, out_units, nobias=True),  # second layer
            l3=L.Linear(out_units, n_out),  # output layer
        )

グローバル変数で定義したcommon_numが、共有する数。また特に意味は無いが3層目のunit数を900に変えた。
##function下のlinear.pyを修正する
chainer/functions/connection/linear.pyのLinearFunctionクラス内forward関数を以下のように修正する。

import cupy
    #modified forward function
    def forward(self, inputs):
        x = _as_mat(inputs[0])
        W = inputs[1]

        #modify to original model
        W_tile = cupy.tile(W.T, (common_num, 1)).astype(W.dtype, copy=False)
        y = x.dot(W_tile).astype(x.dtype, copy=False)

        if len(inputs) == 3:
            b = inputs[2]
            y += b

        return y,

Wをcommon_num倍する際に、cupy(numpy)のtile()を使った。GPU使うのを想定してcupyをimportしているが、使わないならnumpyに変える必要がある。

またcheck_type_forward関数があるとエラーが出るので、コメントアウトする。

    '''
    def check_type_forward(self, in_types):
        n_in = in_types.size()
        type_check.expect(2 <= n_in, n_in <= 3)
        x_type, w_type = in_types[:2]

        type_check.expect(
            x_type.dtype.kind == 'f',
            w_type.dtype.kind == 'f',
            x_type.ndim >= 2,
            w_type.ndim == 2,
            type_check.prod(x_type.shape[1:]) == w_type.shape[1],
        )
        if n_in.eval() == 3:
            b_type = in_types[2]
            type_check.expect(
                b_type.dtype == x_type.dtype,
                b_type.ndim == 1,
                b_type.shape[0] == w_type.shape[0],
            )
    '''

この関数がお節介にもxとWの大きさが対応しているか調べてるみたい。今回、明らかにWだけ小さくしてるので、これが機能するとエラーとなる。

6
3
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?