Why not login to Qiita and try out its useful features?

We'll deliver articles that match you.

You can read useful information later.

4
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

Pytorchでのカスタムオプティマイザー制作方法をまとめた

Last updated at Posted at 2023-08-17

Pytorchでtorch.nn.Moduleを継承してカスタムレイヤーを制作する記事は日本語記事でもかなりありましたが、最適手法をtorch.optim.Optimizerを継承して制作している記事が日本語記事では見当たらなかったので今回記事を書くことにしました。この記事は主に、Writing Your Own Optimizers in PyTorchや、Custom Optimizers in Pytorchといった記事を参考にしています。

この記事の内容は製作中のAttention from scratchというリポジトリから内容を抜粋したものになります。よろしければそちらもご覧ください。

torch.optim.Optimizer継承時の注意点

レイヤーの制作の際はnn.Moduleの継承を行いました。Optimizerはtorch.optim.Optimizerの継承を行います。nn.Moduleでは主にinitとforwardのオーバーライドを行いました。optim.optimizerではinitとstepのオーバーライドを行います。
また、最適化の際はtorch.optim.Optimizerクラスを継承しているのでoptimizer.zero_grad()のようなメソッドも使えることに注意しておきましょう。

  • initの引数
    params: 最適化するパラメータ-をまとめるもの(iterableでなければならない)、model.parameters()で渡すものだと考えてください。
    他のパラメーター: 学習率やOptimizerによっては他のパラメーターがあると思いますが、それにあたります。

  • initの実装上の注意点
    initメソッドではパラメータ-が正当なものかを判別する例外処理を書く必要があることもあります。
    Optimizerクラスを継承する時にはparamsに加えてlrのなどのパラメーターがまとめられたものを辞書型で渡して継承しなければいけません。

  • stepの引数
    closure: Conjugate GradientやLBFGSのような最適化アルゴリズムでは関数を何度も再評価するので必要になるらしいですが、おそらく滅多に使うことはないかと思われます。

  • stepの実装上の注意点
    stepを実行した時点でParametersの勾配(grad)は計算されているものとします。
    計算済みのgradを使って最適化の処理を書くところがstepと考えて良いでしょう。
    optimizerが処理するパラメーターはself.param_groupsにOptimizerクラスを継承した際に入れられています。
    params_groupsは辞書を要素としたリストであり、モデルのパラメーターを個別の要素に分割する方法を提供します。
    例えば、異なる学習率を使用してネットワークの別々のレイヤーをトレーニングする場合などはこれが使われます。
    Pytorchの公式ドキュメントでは以下のように使用例があります。
    torch.optim.Optimizer
    optim.SGD([
                  {'params': model.base.parameters()},
                  {'params': model.classifier.parameters(), 'lr': 1e-3}
              ], lr=1e-2, momentum=0.9)
    

    params_groupの要素となっている辞書は, {"params":, "lr":, "momentum":, }のように、
    Optimizerの継承の時に使用したパラメーターなどを保存しています。

今回はMomentumSGDの実装を行います。
MomentumSGDは以下の式で定義されます。
$$
\text{w}^{t+1} = \text{w}^{t} - η\dfrac{\partial E(\text{w}^t)}{\partial \text{w}^t} + \alpha Δ\text{w}^t
$$
$$
Δ\text{w}^{t+1} = - η\dfrac{\partial E(\text{w}^t)}{\partial \text{w}^t} + \alpha Δ\text{w}^t
$$
デフォルトで$η=0.001, \alpha = 0.9$として実装を行います。
さて、これを実装したものが以下のものになります。
以下のコードでは$η, \alpha$がそれぞれlr(learning rate), momentumという名前に変化していることに注意してください。
また、上の理論式の偏微分部分の値はparameter.grad.dataで呼び出しています。

MomentumSGD.py
import torch
from torch import optim
class MomentumSGD(optim.Optimizer):
    def __init__(self, params, lr = 0.001, momentum = 0.9) -> None:
        if lr < 0:
            raise ValueError(f"Invalid learning rate: lr should be >= 0")
        if momentum < 0:
            raise ValueError(f"Invalid momentum rate: momentum should be >= 0")
        defaults = dict(lr = lr, momentum = momentum)
        super(MomentumSGD, self).__init__(params, defaults)
        self.state = dict()
        for group in self.param_groups:
            for p in group['params']:
                #stateの初期化
                self.state[p] = dict(momentum=torch.zeros_like(p.data))
    def step(self, closure = None) -> None:
        """
        parameterのgradはbackwardメソッドで計算済みと考える。
        更新するパラメーターのt時点での値をW^{t}と表すと、
        W^{t+1} <- W^{t} - lr * W.grad.data + d_W^{t} * momentum
        d_W^{t} <- W^{t+1} - W^{t} =  - lr * W.grad.data + d_W^{t} * momentum
        の式を用いて更新する。
        """
        for group in self.param_groups:
            for p in group['params']:
                if p not in self.state:
                    self.state[p] = dict(momentum=torch.zeros_like(p.data))
                mom = self.state[p]['momentum']
                d_p = - group['lr'] * p.grad.data + group["momentum"] * mom
                p.data += d_p
                self.state[p]['momentum'] = d_p

一応、上手く動作するかを確認しておきましょう。
$y = 5x+1$上にデータが載っている場合の最適化を考えます。

#実験のためにデータを定義
l1 = nn.Linear(1,1)
l2 = nn.Linear(1,1)
x = torch.arange(10).view(-1,1).float()
y = 5*x+10
sgd_mom = MomentumSGD(l1.parameters(), lr = 0.01)
for i in range(175):
    y_pred_mom = l1(x)
    loss_mom = ((y_pred_mom-y)**2).std()
    loss_mom.backward()
    sgd_mom.step()
    sgd_mom.zero_grad()

    if (i+1) % 20 == 0:
        print(f"\nepoch: {i+1}, loss_mom: {loss_mom.item()}")
        print("\nこの時のMomentumのパラメーターの状態\n", sgd_mom.state)

以下のように表示されれば成功となります。

epoch: 20, loss_mom: 305.18133544921875
この時のMomentumのパラメーターの状態
 {Parameter containing:
tensor([[10.1745]], requires_grad=True): {'momentum': tensor([[0.3922]])}, 
Parameter containing:
tensor([-3.0462], requires_grad=True): {'momentum': tensor([-0.1554])}}


...

epoch: 160, loss_mom: 0.0005368478014133871

この時のMomentumのパラメーターの状態
 {Parameter containing:
tensor([[4.9840]], requires_grad=True): {'momentum': tensor([[-0.0166]])}, 
Parameter containing:
tensor([9.8483], requires_grad=True): {'momentum': tensor([0.0051])}}

普通のSGDも制作して同じく175 epochで最適化した時の収束の度合いを確かめてみました。67db315e-28e2-4d04-a35e-1691c458d800.png
やはりMomentumSGDの方が普通のSGDの最適化手法より早く収束していることが確かめられます。

まとめ

今回はtorch.optim.Optimizerを継承して自作の最適化手法を実装する方法を簡単に紹介してみました。もしこれ以上にOptimizerの実装を学ぶ場合は日本語記事はヒットしないので
「(最適化手法) from scratch in Pytorch」のように検索すると良いです。
他にもAdamを実装してみたというリポジトリや、SAM(Sharpness-Aware Minimization for Efficiently Improving Generalization)というOptimizerを実装してみたというリポジトリがあります。是非ご覧になってみてください。

4
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?