LoginSignup
29
31

More than 5 years have passed since last update.

[TF]Optimizerで更新する変数を指定する方法

Posted at

Deep Learningでモデルを作成して学習するときに、学習するパラメータを指定したい時があります。
例えば2つのネットワークを交互に学習している時に、一方のネットワークのパラメータを更新したくない時とかです。

学習するパラメータを指定する方法はいくつかあります。

1.Variableの引数のtrainableをFalseにする

python
x = tf.Variable(tf.constant([2.]), name='x', trainable=False)

2.Optimizerに更新する変数のリストを渡す

python
opt = tf.train.GradientDescentOptimizer(1.0)
train = opt.minimize(f, var_list=[w,b])

2番の方がコードの変更が楽です。

実行結果

なんにもしない時

下記のようなコードを実行すると変数w, b, x, y_が更新されてしまいます。
(下記のコード自体実用性は殆ど無いですし、通常はx, y_をplaceholderにするので問題にはなりませんが、例ということでご容赦ください。)

python
import tensorflow as tf
import numpy as np

w = tf.Variable(tf.constant([3.]), name='w')
b = tf.Variable(tf.constant([1.]), name='b')
x = tf.Variable(tf.constant([2.]), name='x')
y_ = tf.Variable(tf.constant([5.]), name='y_')

p = w*x
y = p+b
s = -y
t = s +y_
f = t*t

gx, gb, gw, gp, gy, gy_,gs, gt, gf = tf.gradients(f, [x, b, w, p, y, y_,s, t, f])

init = tf.initialize_all_variables()

opt = tf.train.GradientDescentOptimizer(1.0)
train = opt.minimize(f)

with tf.Session() as sess:
    sess.run(init)
    print 'x:%.2f, w:%.2f, b:%.2f' % (sess.run(x), sess.run(w), sess.run(b))
    print 'p:%.2f, y:%.2f, y_:%.2f'% (sess.run(p), sess.run(y), sess.run(y_))
    print 's:%.2f, t:%.2f, f:%.2f' % (sess.run(s), sess.run(t), sess.run(f))

    print '---------- gradient ----------'
    print 'gx:%.2f, gw:%.2f, gb: %.2f' % (sess.run(gx), sess.run(gw), sess.run(gb))
    print 'gp:%.2f, gy:%.2f, gy_:%.2f' %(sess.run(gp), sess.run(gy), sess.run(gy_))
    print 'gs:%.2f, gt:%.2f, gf:%.2f' %(sess.run(gs), sess.run(gt), sess.run(gf))
    print '---------- run GradientDescentOptimizer ----------'
    sess.run(train)

    print 'x:%.2f, w:%.2f, b:%.2f' % (sess.run(x), sess.run(w), sess.run(b))
    print 'p:%.2f, y:%.2f, y_:%.2f'% (sess.run(p), sess.run(y), sess.run(y_))
    print 's:%.2f, t:%.2f, f:%.2f'%(sess.run(s), sess.run(t), sess.run(f))

    print '---------- gradient ----------'
    print 'gx:%.2f, gw:%.2f, gb: %.2f' % (sess.run(gx), sess.run(gw), sess.run(gb))
    print 'gp:%.2f, gy:%.2f, gy_:%.2f' %(sess.run(gp), sess.run(gy), sess.run(gy_))
    print 'gs:%.2f, gt:%.2f, gf:%.2f' %(sess.run(gs), sess.run(gt), sess.run(gf))

実行結果
sess.run(train)のあとでxとy_の値も更新されています。

x:2.00, w:3.00, b:1.00
p:6.00, y:7.00, y_:5.00
s:-7.00, t:-2.00, f:4.00
---------- gradient ----------
gx:12.00, gw:8.00, gb: 4.00
gp:4.00, gy:4.00, gy_:-4.00
gs:-4.00, gt:-4.00, gf:1.00
---------- run GradientDescentOptimizer ----------
x:-10.00, w:-5.00, b:-3.00
p:50.00, y:47.00, y_:9.00
s:-47.00, t:-38.00, f:1444.00
---------- gradient ----------
gx:-380.00, gw:-760.00, gb: 76.00
gp:76.00, gy:76.00, gy_:-76.00
gs:-76.00, gt:-76.00, gf:1.00

trainableを指定した場合

python
import tensorflow as tf
import numpy as np

w = tf.Variable(tf.constant([3.]), name='w')
b = tf.Variable(tf.constant([1.]), name='b')
x = tf.Variable(tf.constant([2.]), name='x', trainable=False)
y_ = tf.Variable(tf.constant([5.]), name='y_', trainable=False)

p = w*x
y = p+b
s = -y
t = s +y_
f = t*t

gx, gb, gw, gp, gy, gy_,gs, gt, gf = tf.gradients(f, [x, b, w, p, y, y_,s, t, f])

init = tf.initialize_all_variables()

opt = tf.train.GradientDescentOptimizer(1.0)
train = opt.minimize(f)

with tf.Session() as sess:
    sess.run(init)
    print 'x:%.2f, w:%.2f, b:%.2f' % (sess.run(x), sess.run(w), sess.run(b))
    print 'p:%.2f, y:%.2f, y_:%.2f'% (sess.run(p), sess.run(y), sess.run(y_))
    print 's:%.2f, t:%.2f, f:%.2f' % (sess.run(s), sess.run(t), sess.run(f))

    print '---------- gradient ----------'
    print 'gx:%.2f, gw:%.2f, gb: %.2f' % (sess.run(gx), sess.run(gw), sess.run(gb))
    print 'gp:%.2f, gy:%.2f, gy_:%.2f' %(sess.run(gp), sess.run(gy), sess.run(gy_))
    print 'gs:%.2f, gt:%.2f, gf:%.2f' %(sess.run(gs), sess.run(gt), sess.run(gf))
    print '---------- run GradientDescentOptimizer ----------'
    sess.run(train)

    print 'x:%.2f, w:%.2f, b:%.2f' % (sess.run(x), sess.run(w), sess.run(b))
    print 'p:%.2f, y:%.2f, y_:%.2f'% (sess.run(p), sess.run(y), sess.run(y_))
    print 's:%.2f, t:%.2f, f:%.2f'%(sess.run(s), sess.run(t), sess.run(f))

    print '---------- gradient ----------'
    print 'gx:%.2f, gw:%.2f, gb: %.2f' % (sess.run(gx), sess.run(gw), sess.run(gb))
    print 'gp:%.2f, gy:%.2f, gy_:%.2f' %(sess.run(gp), sess.run(gy), sess.run(gy_))
    print 'gs:%.2f, gt:%.2f, gf:%.2f' %(sess.run(gs), sess.run(gt), sess.run(gf))

実行結果
sess.run(train)のあとでもxとy_の値も変わっていないのがわかります。

x:2.00, w:3.00, b:1.00
p:6.00, y:7.00, y_:5.00
s:-7.00, t:-2.00, f:4.00
---------- gradient ----------
gx:12.00, gw:8.00, gb: 4.00
gp:4.00, gy:4.00, gy_:-4.00
gs:-4.00, gt:-4.00, gf:1.00
---------- run GradientDescentOptimizer ----------
x:2.00, w:-5.00, b:-3.00
p:-10.00, y:-13.00, y_:5.00
s:13.00, t:18.00, f:324.00
---------- gradient ----------
gx:180.00, gw:-72.00, gb: -36.00
gp:-36.00, gy:-36.00, gy_:36.00
gs:36.00, gt:36.00, gf:1.00

Optimizerに更新する変数のリストを渡す場合

Optimizerに変数のリストを渡す場合は、minimizeの引数としてvar_listを渡します。

python
import tensorflow as tf
import numpy as np

w = tf.Variable(tf.constant([3.]), name='w')
b = tf.Variable(tf.constant([1.]), name='b')
x = tf.Variable(tf.constant([2.]), name='x')
y_ = tf.Variable(tf.constant([5.]), name='y_')

p = w*x
y = p+b
s = -y
t = s +y_
f = t*t


gx, gb, gw, gp, gy, gy_,gs, gt, gf = tf.gradients(f, [x, b, w, p, y, y_,s, t, f])

init = tf.initialize_all_variables()

opt = tf.train.GradientDescentOptimizer(1.0)
train = opt.minimize(f, var_list=[w,b])

with tf.Session() as sess:
    sess.run(init)
    print 'x:%.2f, w:%.2f, b:%.2f' % (sess.run(x), sess.run(w), sess.run(b))
    print 'p:%.2f, y:%.2f, y_:%.2f'% (sess.run(p), sess.run(y), sess.run(y_))
    print 's:%.2f, t:%.2f, f:%.2f' % (sess.run(s), sess.run(t), sess.run(f))

    print '---------- gradient ----------'
    print 'gx:%.2f, gw:%.2f, gb: %.2f' % (sess.run(gx), sess.run(gw), sess.run(gb))
    print 'gp:%.2f, gy:%.2f, gy_:%.2f' %(sess.run(gp), sess.run(gy), sess.run(gy_))
    print 'gs:%.2f, gt:%.2f, gf:%.2f' %(sess.run(gs), sess.run(gt), sess.run(gf))
    print '---------- run GradientDescentOptimizer ----------'
    sess.run(train)

    print 'x:%.2f, w:%.2f, b:%.2f' % (sess.run(x), sess.run(w), sess.run(b))
    print 'p:%.2f, y:%.2f, y_:%.2f'% (sess.run(p), sess.run(y), sess.run(y_))
    print 's:%.2f, t:%.2f, f:%.2f'%(sess.run(s), sess.run(t), sess.run(f))

    print '---------- gradient ----------'
    print 'gx:%.2f, gw:%.2f, gb: %.2f' % (sess.run(gx), sess.run(gw), sess.run(gb))
    print 'gp:%.2f, gy:%.2f, gy_:%.2f' %(sess.run(gp), sess.run(gy), sess.run(gy_))
    print 'gs:%.2f, gt:%.2f, gf:%.2f' %(sess.run(gs), sess.run(gt), sess.run(gf))

実行結果

x:2.00, w:3.00, b:1.00
p:6.00, y:7.00, y_:5.00
s:-7.00, t:-2.00, f:4.00
---------- gradient ----------
gx:12.00, gw:8.00, gb: 4.00
gp:4.00, gy:4.00, gy_:-4.00
gs:-4.00, gt:-4.00, gf:1.00
---------- run GradientDescentOptimizer ----------
x:2.00, w:-5.00, b:-3.00
p:-10.00, y:-13.00, y_:5.00
s:13.00, t:18.00, f:324.00
---------- gradient ----------
gx:180.00, gw:-72.00, gb: -36.00
gp:-36.00, gy:-36.00, gy_:36.00
gs:36.00, gt:36.00, gf:1.00

わざわざvar_listに変数を列挙するのが面倒な場合はscopeを使うと少し楽になります。
手順は下記のようにします。

  1. 変数を宣言するときにscopeを使う。
  2. get_collectionでscopeを指定して、そのscopeの変数リストを取得する
python
import tensorflow as tf
import numpy as np

with tf.variable_scope("params"):
    w = tf.Variable(tf.constant([3.]), name='w')
    b = tf.Variable(tf.constant([1.]), name='b')

with tf.variable_scope("input"):
    x = tf.Variable(tf.constant([2.]), name='x')
    y_ = tf.Variable(tf.constant([5.]), name='y_')

with tf.variable_scope("intermediate"):
    p = w*x
    y = p+b
    s = -y
    t = s +y_
    f = t*t    


gx, gb, gw, gp, gy, gy_,gs, gt, gf = tf.gradients(f, [x, b, w, p, y, y_,s, t, f])

train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="params")
print 'train_vars'
for v in train_vars:
    print v.name

init = tf.initialize_all_variables()

opt = tf.train.GradientDescentOptimizer(1.0)
train = opt.minimize(f, var_list=train_vars)

with tf.Session() as sess:
    sess.run(init)
    print 'x:%.2f, w:%.2f, b:%.2f' % (sess.run(x), sess.run(w), sess.run(b))
    print 'p:%.2f, y:%.2f, y_:%.2f'% (sess.run(p), sess.run(y), sess.run(y_))
    print 's:%.2f, t:%.2f, f:%.2f' % (sess.run(s), sess.run(t), sess.run(f))

    print '---------- gradient ----------'
    print 'gx:%.2f, gw:%.2f, gb: %.2f' % (sess.run(gx), sess.run(gw), sess.run(gb))
    print 'gp:%.2f, gy:%.2f, gy_:%.2f' %(sess.run(gp), sess.run(gy), sess.run(gy_))
    print 'gs:%.2f, gt:%.2f, gf:%.2f' %(sess.run(gs), sess.run(gt), sess.run(gf))
    print '---------- run GradientDescentOptimizer ----------'
    sess.run(train)

    print 'x:%.2f, w:%.2f, b:%.2f' % (sess.run(x), sess.run(w), sess.run(b))
    print 'p:%.2f, y:%.2f, y_:%.2f'% (sess.run(p), sess.run(y), sess.run(y_))
    print 's:%.2f, t:%.2f, f:%.2f'%(sess.run(s), sess.run(t), sess.run(f))

    print '---------- gradient ----------'
    print 'gx:%.2f, gw:%.2f, gb: %.2f' % (sess.run(gx), sess.run(gw), sess.run(gb))
    print 'gp:%.2f, gy:%.2f, gy_:%.2f' %(sess.run(gp), sess.run(gy), sess.run(gy_))
    print 'gs:%.2f, gt:%.2f, gf:%.2f' %(sess.run(gs), sess.run(gt), sess.run(gf))

実行結果

train_vars
params/w:0
params/b:0
x:2.00, w:3.00, b:1.00
p:6.00, y:7.00, y_:5.00
s:-7.00, t:-2.00, f:4.00
---------- gradient ----------
gx:12.00, gw:8.00, gb: 4.00
gp:4.00, gy:4.00, gy_:-4.00
gs:-4.00, gt:-4.00, gf:1.00
---------- run GradientDescentOptimizer ----------
x:2.00, w:-5.00, b:-3.00
p:-10.00, y:-13.00, y_:5.00
s:13.00, t:18.00, f:324.00
---------- gradient ----------
gx:180.00, gw:-72.00, gb: -36.00
gp:-36.00, gy:-36.00, gy_:36.00
gs:36.00, gt:36.00, gf:1.00
29
31
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
29
31