2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

[tensorflow]tf.get_variable()で変数を参照する

Posted at

tensorflowでL2正則化の導入のためにネットワークのパラメータを参照する必要があり,見事にハマったのでメモ.

L2正則化

ここでいう正則化は,過学習を防ぐためにパラメータが発散するのを抑えること.
学習データにフィットするモデルを作る際に,なるべくシンプルなモデルを作った方が未知のデータに対応できる.逆に複雑すぎるモデルを作ると過学習を起こす.こことか見るとわかりやすいかも.

正則化は,損失関数の定義にパラメータそのものの大きさを加えれば実現できる.大きさにパラメータの絶対値を用いるものをL1正則化,2乗の値を用いるものをL2正則化と呼ぶ.

ネットワークのパラメータを参照する

正則化を実現するためには,損失関数の定義式の中からネットワークのパラメータを取ってくる必要があるが,変数名をそのまま指定してもダメらしい.

以下のようにmy_CNNクラスとTrainクラスを定義したとして,

class my_CNN:
  with tf.variable_scope("conv1", reuse=False):
    filter1 = tf.get_variable("weights", shape=[28, 28])

class Train:
  self.cnn = my_CNN()
  def loss(self):
    tf.nn.l2_loss(ここでfilter1を取ってきたい)

my_CNNクラスで定義したfilter(最適化するパラメータ)を,Trainクラスのlossから参照したい場合,単にtf.nn.l2_loss(self.cnn.filter1)とやっても取ってこれない.

tf.get_variable()を使う

tf.get_variable()は引数にnameを取り,そのnameの変数が既に存在する場合はそれを返し,存在しない場合は新しく作る.
ただし既に存在する変数を取ってきたい場合は,tf.get_variable()するscopeでreuse=Trueとする必要がある.
今回は,my_CNNの中で宣言するときはreuse=False,Trainの中で宣言するときはreuse=Trueとすると参照できる.

class my_CNN:
  with tf.variable_scope('conv1', reuse=False):
    filter1 = tf.get_variable("weights", shape=[28, 28])

class Train:
  self.cnn = my_CNN()
  def loss(self):
    with tf.variable_scope('conv1', reuse=True):
      got_variable = tf.get_variable("weights", shape=[28, 28])
    tf.nn.l2_loss(got_variable)
2
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?