tensorflowでL2正則化の導入のためにネットワークのパラメータを参照する必要があり,見事にハマったのでメモ.
L2正則化
ここでいう正則化は,過学習を防ぐためにパラメータが発散するのを抑えること.
学習データにフィットするモデルを作る際に,なるべくシンプルなモデルを作った方が未知のデータに対応できる.逆に複雑すぎるモデルを作ると過学習を起こす.こことか見るとわかりやすいかも.
正則化は,損失関数の定義にパラメータそのものの大きさを加えれば実現できる.大きさにパラメータの絶対値を用いるものをL1正則化,2乗の値を用いるものをL2正則化と呼ぶ.
ネットワークのパラメータを参照する
正則化を実現するためには,損失関数の定義式の中からネットワークのパラメータを取ってくる必要があるが,変数名をそのまま指定してもダメらしい.
以下のようにmy_CNNクラスとTrainクラスを定義したとして,
class my_CNN:
with tf.variable_scope("conv1", reuse=False):
filter1 = tf.get_variable("weights", shape=[28, 28])
class Train:
self.cnn = my_CNN()
def loss(self):
tf.nn.l2_loss(ここでfilter1を取ってきたい)
my_CNNクラスで定義したfilter(最適化するパラメータ)を,Trainクラスのlossから参照したい場合,単にtf.nn.l2_loss(self.cnn.filter1)
とやっても取ってこれない.
tf.get_variable()を使う
tf.get_variable()は引数にnameを取り,そのnameの変数が既に存在する場合はそれを返し,存在しない場合は新しく作る.
ただし既に存在する変数を取ってきたい場合は,tf.get_variable()するscopeでreuse=True
とする必要がある.
今回は,my_CNNの中で宣言するときはreuse=False
,Trainの中で宣言するときはreuse=True
とすると参照できる.
class my_CNN:
with tf.variable_scope('conv1', reuse=False):
filter1 = tf.get_variable("weights", shape=[28, 28])
class Train:
self.cnn = my_CNN()
def loss(self):
with tf.variable_scope('conv1', reuse=True):
got_variable = tf.get_variable("weights", shape=[28, 28])
tf.nn.l2_loss(got_variable)