##環境
GPU GTX1070
ubuntu 14.04
chainer 1.14.0
など
##はじめに
chainerで最新のモデルを実装する際には、links/connectionやfunctions/connectionをいじる必要がある。
そこで最も単純なlinear.pyをいじって、新しい層を作ってみる。前回はlinear.pyの中身を確認した。
http://qiita.com/masataka46/items/d66997ac94ec7aa3bcb4
今回はchainer/functions/connection/linear.py
のforward関数をいじって順伝播を改良する。
##改良モデルの概要
元々の全結合3層を以下の図のように改良する。
2層目だけを改良する。この2層目は具体的に以下のように入力側に関して重みを共有する。
この演算処理は以下の図のようになる。
Wは入力側n個で重みを共有するので、W(out_size, in_size / n)となる。この重みを1度の行列積で計算できるよう、in_size / n側をn倍し、in_sizeとする。
この重みと入力側からのデータxとの行列積を求めると、y(batch_size, out_size)が出力される。これにより重みのパラメーター数が減るので、演算は速くなるだろう。そして、性能が若干低下するだろう。また今回、計算を簡略化するためbiasは使わないでおく。
##tain_mnist.pyを修正する
train_mnist.pyも若干変わってくるので修正する。
common_num = 10
out_units = 900
#chnged model
def __init__(self, n_in, n_units, n_out):
super(MLP, self).__init__(
l1=L.Linear(n_in, n_units), # first layer
l2=linear_link.Linear(n_units / common_num, out_units, nobias=True), # second layer
l3=L.Linear(out_units, n_out), # output layer
)
グローバル変数で定義したcommon_numが、共有する数。また特に意味は無いが3層目のunit数を900に変えた。
##function下のlinear.pyを修正する
chainer/functions/connection/linear.py
のLinearFunctionクラス内forward関数を以下のように修正する。
import cupy
#modified forward function
def forward(self, inputs):
x = _as_mat(inputs[0])
W = inputs[1]
#modify to original model
W_tile = cupy.tile(W.T, (common_num, 1)).astype(W.dtype, copy=False)
y = x.dot(W_tile).astype(x.dtype, copy=False)
if len(inputs) == 3:
b = inputs[2]
y += b
return y,
Wをcommon_num倍する際に、cupy(numpy)のtile()を使った。GPU使うのを想定してcupyをimportしているが、使わないならnumpyに変える必要がある。
またcheck_type_forward関数があるとエラーが出るので、コメントアウトする。
'''
def check_type_forward(self, in_types):
n_in = in_types.size()
type_check.expect(2 <= n_in, n_in <= 3)
x_type, w_type = in_types[:2]
type_check.expect(
x_type.dtype.kind == 'f',
w_type.dtype.kind == 'f',
x_type.ndim >= 2,
w_type.ndim == 2,
type_check.prod(x_type.shape[1:]) == w_type.shape[1],
)
if n_in.eval() == 3:
b_type = in_types[2]
type_check.expect(
b_type.dtype == x_type.dtype,
b_type.ndim == 1,
b_type.shape[0] == w_type.shape[0],
)
'''
この関数がお節介にもxとWの大きさが対応しているか調べてるみたい。今回、明らかにWだけ小さくしてるので、これが機能するとエラーとなる。