LoginSignup
37
22

More than 5 years have passed since last update.

Chainer v2でパラメータごとに更新ルールの設定を行う

Posted at

はじめに

Chainerではv2からUpdateRuleが導入され、UpdateRuleインスタンスの操作を行うことでパラメータごとの更新ルールを設定することができるようになりました。
例えば学習率を変えたり、一部のパラメータの更新を抑制することができます。

パラメータとは

本投稿でのパラメータはchainer.Parameterインスタンスを指します。
chainer.Parameterchainer.Variableを継承したクラスで、chainer.Linkが持つパラメータを保持する目的で使われます。
例えばchainer.functions.Convolution2DWbという2つのパラメータを持ちます。

UpdateRule

chainer.UpdateRuleはパラメータをどのように更新するかを定めたクラスです。
SGD等の更新アルゴリズムに対応した派生クラスが存在します。
UpdateRuleは以下の属性を持ちます。

  • enabled(bool): パラメータを更新するかどうかを表すフラグです。
  • hyperparam(chainer.Hyperparameter): 更新アルゴリズムに関わるハイパーパラメータを保持します。
  • t(int): 更新ルールによって更新された回数を保持します。

enabledやhyperparamを操作することでパラメータの更新を止めたり学習率を変更したりすることができます。

UpdateRuleが生成されるタイミング

各パラメータが持つUpdateRuleインスタンスはchainer.Optimizerインスタンスのsetup()を呼び出したタイミングで生成されます。

以下のようなニューラルネットワークを構築したとします。

class MLP(chainer.Chain):
    def __init__(self):
        super(MLP, self).__init__()
        with self.init_scope():
            self.l1 = L.Linear(2, 2)
            self.l2 = L.Linear(2, 1)

    def __call__(self, x):
        h = self.l1(x)
        h = self.l2(h)
        return h

パラメータの更新を止める

パラメータの更新は、パラメータ単位、Link単位で止めることができます。

パラメータ単位で指定する

特定のパラメータの更新を行わないようにするには、update_rule.enabledをFalseにすればよいです。
例:

net.l1.W.update_rule.enabled = False

Link単位で指定する

Linkの更新を行わないようにするには、disable_update()を呼べばよいです。
逆にLinkが持つ全てのパラメータを更新するためにはenable_updateを呼び出してください。

例:

net.l1.disable_update()

ハイパーパラメータを変更する

学習率等のハイパーパラメータはhyperparamの属性を操作することで変更できます。

例:

net.l1.W.update_rule.hyperparam.lr = 1.0

hook functionの追加

update_rule.add_hookを呼び出すことでchainer.optimizer.WeightDecay等のhook functionをパラメータ単位で設定することができます。

例:

net.l1.W.update_rule.add_hook(chainer.optimizer.WeightDecay(0.0001))

試してみる

例として一部のパラメータの学習率を大きくし、別のパラメータの更新を止めてみます。

# -*- coding: utf-8 -*-

import numpy as np

import chainer
from chainer import functions as F
from chainer import links as L
from chainer import optimizers


class MLP(chainer.Chain):
    def __init__(self):
        super(MLP, self).__init__()
        with self.init_scope():
            self.l1 = L.Linear(2, 2)
            self.l2 = L.Linear(2, 1)

    def __call__(self, x):
        h = self.l1(x)
        h = self.l2(h)
        return h


net = MLP()
optimizer = optimizers.SGD(lr=0.1)

# setupを呼び出すとUpdateRuleが生成される
optimizer.setup(net)

net.l1.W.update_rule.hyperparam.lr = 10.0
net.l1.b.update_rule.enabled = False

x = np.asarray([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
t = np.asarray([[0], [1], [1], [0]], dtype=np.int32)

y = net(x)

print('before')
print('l1.W')
print(net.l1.W.data)
print('l1.b')
print(net.l1.b.data)
print('l2.W')
print(net.l2.W.data)
print('l2.b')
print(net.l2.b.data)

loss = F.sigmoid_cross_entropy(y, t)
net.cleargrads()
loss.backward()
optimizer.update()

print('after')
print('l1.W')
print(net.l1.W.data)
print('l1.b')
print(net.l1.b.data)
print('l2.W')
print(net.l2.W.data)
print('l2.b')
print(net.l2.b.data)

実行結果は以下のようになりました。
l1.Wの変更量がl2.Wの変更量よりもかなり大きく、l1.bは変更されていないことがわかります。

before
l1.W
[[ 0.0049778  -0.16282777]
 [-0.92988533  0.2546134 ]]
l1.b
[ 0.  0.]
l2.W
[[-0.45893994 -1.21258962]]
l2.b
[ 0.]
after
l1.W
[[ 0.53748596  0.01032409]
 [ 0.47708291  0.71210718]]
l1.b
[ 0.  0.]
l2.W
[[-0.45838338 -1.20276082]]
l2.b
[-0.01014706]
37
22
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
37
22