0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

PyTorch の Adam の検算

0
Last updated at Posted at 2026-06-13

torch.optim.Adam が、Adam の更新式通りにパラメータを更新してくれているかを簡単な例で確認するスクリプトを書きました。

入出力が 1 次元の線形モデル y(x) = w * x + b を用意し、MSE 損失の勾配を計算します。その後、Adam の1, 2 次モーメント推定値およびパラメータ更新幅を自力計算し、optimizer.step() 後のパラメータが自力計算結果に一致することを確認しています。

script.py
import torch
from torch import tensor, allclose


def main():
    # 損失関数とモデルの用意
    criterion = torch.nn.MSELoss()  # L(y, y_true) = (y - y_true)^2
    model = torch.nn.Linear(1, 1)  # y(x) = w * x + b
    # ちなみにこのときあるサンプルの損失の w, b についての勾配は以下
    # ∂L(y, y_true)/∂y = 2 * (y - y_true)
    # ∂L(y, y_true)/∂w = 2 * x * (y - y_true)
    # ∂L(y, y_true)/∂b = 2 * (y - y_true)

    # モデルに初期パラメータをセット
    w, b = -0.5, 1.0  # 適当
    with torch.no_grad():
        model.weight.copy_(tensor([[w]]))
        model.bias.copy_(tensor([b]))

    # Adam オプティマイザの用意
    # https://docs.pytorch.org/docs/2.12/generated/torch.optim.Adam.html
    lr, betas, eps = 0.1, (0.9, 0.99), 1e-08
    optimizer = torch.optim.Adam(model.parameters(), lr, betas, eps)

    # Adam の勾配の 1, 2 次モーメント推定値の更新式
    def update(m, v, g):
        return betas[0] * m + (1 - betas[0]) * g, betas[1] * v + (1 - betas[1]) * g * g

    # Adam のパラメータ更新幅 (方向付き)
    def get_step(m, v, i_step):
        m_hat = m / (1 - betas[0] ** i_step)  # 初期値 0 で初期の推定値が小さいので補正
        v_hat = v / (1 - betas[1] ** i_step)  # 初期値 0 で初期の推定値が小さいので補正
        return -1. * lr * m_hat / (v_hat ** 0.5 + eps)

    def _print(i_step, w, b, loss, model, m_w, v_w, m_b, v_b, step_w, step_b):
        print(
            f'========== {i_step} 回目の更新結果 ==========\n'
            '* 更新前の w と b は以下でした\n'
            f'  * w = {w:.4f}\n'
            f'  * b = {b:.4f}\n'
            '* 更新前の損失は以下でした\n'
            f'  * L = {loss.item():.4f}\n'
            '* 更新前の損失の w と b に関する勾配は以下でした\n'
            f'  * grad L_w = {model.weight.grad.item():.4f}'
            f' --> m_w = {m_w:.4f}, v_w = {v_w:.4f}\n'
            f'  * grad L_b = {model.bias.grad.item():.4f}'
            f' --> m_b = {m_b:.4f}, v_b = {v_b:.4f}\n'
            '* 算出された w と b の更新幅は以下でした\n'
            f'  * step_w = {step_w:.4f}\n'
            f'  * step_b = {step_b:.4f}\n'
        )

    # ============================================================
    # 1 回目の更新
    # ============================================================
    # 2 個のダミーサンプルの予測値を計算し損失をとる
    x, y_true = tensor([[1.0], [2.0]]), tensor([[1.0], [2.0]])
    optimizer.zero_grad()
    y = model(x)
    loss = criterion(y, y_true)
    assert allclose(y, tensor([[0.5], [0.0]]))
    assert allclose(loss, tensor([2.125]))  # ∵ (0.25 + 4) --[平均]-> 2.125

    # 損失の勾配をセット
    loss.backward()
    assert allclose(model.weight.grad, tensor([[-4.5]]))
    assert allclose(model.bias.grad, tensor([-2.5]))
    # ∵ Σ∂L/∂w = 2 * 1 * (0.5 - 1) + 2 * 2 * (0 - 2) = -9 --[平均]-> -4.5
    # ∵ Σ∂L/∂b = 2 * (0.5 - 1) + 2 * (0 - 2) = -5 --[平均]-> -2.5

    # (自力計算) 勾配の1, 2 次モーメント推定値の更新・それを用いた更新幅計算
    m_w, v_w, m_b, v_b = 0., 0., 0., 0.  # モーメントの初期値はゼロ
    i_step = 1
    m_w, v_w = update(m_w, v_w, -4.5)
    m_b, v_b = update(m_b, v_b, -2.5)
    step_w = get_step(m_w, v_w, i_step)
    step_b = get_step(m_b, v_b, i_step)

    # 勾配更新して更新結果が自力計算結果と合致していることを確認
    optimizer.step()
    assert allclose(model.weight.data, tensor([[w + step_w]]))
    assert allclose(model.bias.data, tensor([b + step_b]))

    # オプティマイザに記録されたモーメント推定値も自力計算結果と合致していることを確認
    state_w = optimizer.state[model.weight]
    assert allclose(state_w['step'], tensor(1.))
    assert allclose(state_w['exp_avg'], tensor([[m_w]]))
    assert allclose(state_w['exp_avg_sq'], tensor([[v_w]]))
    state_b = optimizer.state[model.bias]
    assert allclose(state_b['step'], tensor(1.))
    assert allclose(state_b['exp_avg'], tensor([[m_b]]))
    assert allclose(state_b['exp_avg_sq'], tensor([[v_b]]))

    _print(i_step, w, b, loss, model, m_w, v_w, m_b, v_b, step_w, step_b)
    w += step_w
    b += step_b

    # ============================================================
    # 2 回目以降
    # ============================================================
    for _ in range(100):
        optimizer.zero_grad()
        y = model(x)
        loss = criterion(y, y_true)
        loss.backward()  # 損失の勾配をセット

        # (自力計算) 勾配の1, 2 次モーメント推定値の更新・それを用いた更新幅計算
        i_step += 1
        m_w, v_w = update(m_w, v_w, model.weight.grad.item())
        m_b, v_b = update(m_b, v_b, model.bias.grad.item())
        step_w = get_step(m_w, v_w, i_step)
        step_b = get_step(m_b, v_b, i_step)

        # 勾配更新して更新結果が自力計算結果と合致していることを確認
        optimizer.step()
        assert allclose(model.weight.data, tensor([[w + step_w]]))
        assert allclose(model.bias.data, tensor([b + step_b]))

        # オプティマイザに記録されたモーメント推定値も自力計算結果と合致していることを確認
        state_w = optimizer.state[model.weight]
        assert allclose(state_w['step'], tensor(float(i_step)))
        assert allclose(state_w['exp_avg'], tensor([[m_w]]))
        assert allclose(state_w['exp_avg_sq'], tensor([[v_w]]))
        state_b = optimizer.state[model.bias]
        assert allclose(state_b['step'], tensor(float(i_step)))
        assert allclose(state_b['exp_avg'], tensor([[m_b]]))
        assert allclose(state_b['exp_avg_sq'], tensor([[v_b]]))

        _print(i_step, w, b, loss, model, m_w, v_w, m_b, v_b, step_w, step_b)
        w += step_w
        b += step_b


if __name__ == '__main__':
    main()

上記のスクリプトを実行すると assert に抵触することなく実行できます。
また、以下のようにデバッグプリントされます。

========== 1 回目の更新結果 ==========
* 更新前の w と b は以下でした
  * w = -0.5000
  * b = 1.0000
* 更新前の損失は以下でした
  * L = 2.1250
* 更新前の損失の w と b に関する勾配は以下でした
  * grad L_w = -4.5000 --> m_w = -0.4500, v_w = 0.2025
  * grad L_b = -2.5000 --> m_b = -0.2500, v_b = 0.0625
* 算出された w と b の更新幅は以下でした
  * step_w = 0.1000
  * step_b = 0.1000

========== 2 回目の更新結果 ==========
* 更新前の w と b は以下でした
  * w = -0.4000
  * b = 1.1000
* 更新前の損失は以下でした
  * L = 1.4900
* 更新前の損失の w と b に関する勾配は以下でした
  * grad L_w = -3.7000 --> m_w = -0.7750, v_w = 0.3374
  * grad L_b = -2.0000 --> m_b = -0.4250, v_b = 0.1019
* 算出された w と b の更新幅は以下でした
  * step_w = 0.0991
  * step_b = 0.0989

========== 3 回目の更新結果 ==========
* 更新前の w と b は以下でした
  * w = -0.3009
  * b = 1.1989
* 更新前の損失は以下でした
  * L = 0.9894
* 更新前の損失の w と b に関する勾配は以下でした
  * grad L_w = -2.9081 --> m_w = -0.9883, v_w = 0.4186
  * grad L_b = -1.5051 --> m_b = -0.5330, v_b = 0.1235
* 算出された w と b の更新幅は以下でした
  * step_w = 0.0971
  * step_b = 0.0964

(中略)

========== 101 回目の更新結果 ==========
* 更新前の w と b は以下でした
  * w = 0.9091
  * b = 0.1520
* 更新前の損失は以下でした
  * L = 0.0023
* 更新前の損失の w と b に関する勾配は以下でした
  * grad L_w = 0.0012 --> m_w = -0.0277, v_w = 0.2849
  * grad L_b = 0.0311 --> m_b = 0.0301, v_b = 0.1097
* 算出された w と b の更新幅は以下でした
  * step_w = 0.0041
  * step_b = -0.0073
0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?