Introduction
Gauge theory struck me as similar to generative adversarial networks (GANs), adversarial imitation learning (GAIL), and RLHF, so I had a generative AI write up some material on the idea. I'm releasing that material here; if you read it, I'd be happy to hear your impressions.
Material
The key to stable GAN training may be "momentum." This article looks into why Adam stabilizes GAN training where SGD does not.
Despite their remarkable image-generation ability, GANs have long suffered from a critical weakness: unstable training. We trace the root cause of this instability using the physical concepts of inertia and damping.
Although the metric used here is not the strictly rigorous one of the underlying theory, we construct an approximate theory. Furthermore, by viewing the learning process as a dissipative system, we show that the presence or absence of an inertia term in the optimizer affects the stability and convergence of GAN training.
In addition, by formulating adversarial learning in the language of gauge theory, we derive Natural Gradient with Momentum.
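As a reading aid, here is my own minimal sketch (not an equation quoted from the material) of the dynamics referred to above: damped heavy-ball motion preconditioned by a metric $F(\theta)$,

$$\ddot{\theta} + \gamma\,\dot{\theta} = -\,F(\theta)^{-1}\,\nabla_\theta L(\theta),$$

whose discretization with momentum coefficient $\mu$ (roughly $\mu \approx 1 - \gamma\,\Delta t$), step size $\eta$, and Tikhonov damping $\lambda$ on the metric reads

$$v_{t+1} = \mu\,v_t - \eta\,\bigl(F(\theta_t)+\lambda I\bigr)^{-1}\nabla_\theta L(\theta_t),\qquad \theta_{t+1} = \theta_t + v_{t+1}.$$

This is exactly the update performed by the ng_momentum branch of the code below; sgd corresponds to $F = I,\ \mu = 0$, and ng to $\mu = 0$.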
Implementation and results
# -*- coding: utf-8 -*-
"""
GAN training with 4 optimizers: SGD / Adam / Natural Gradient / Natural Gradient + Momentum
+ Curvature (F-hat) visualization + antisymmetric Jacobian (projected) + (optional) game-K.
Author: Hideki Yoshida (refined by Copilot)
"""
import math, time, copy, random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
# ----------------------------- Utilities -------------------------------------
def set_seed(seed=42):
random.seed(seed); np.random.seed(seed)
torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)
def get_device():
return "cuda" if torch.cuda.is_available() else "cpu"
def flatten_params(model):
return torch.cat([p.reshape(-1) for p in model.parameters()])
def save_state(model):
return copy.deepcopy(model.state_dict())
def load_state(model, state):
model.load_state_dict(state)
# --------------------------- Data: 1D Gaussian -------------------------------
real_mean, real_std = 8.0, 1.25
def get_real_data(bs, device):
return torch.randn(bs,1,device=device)*real_std + real_mean
# --------------------------- Models ------------------------------------------
class Generator(nn.Module):
def __init__(self, latent_dim=16, hidden_dim=64, out_dim=1):
super().__init__()
self.net = nn.Sequential(
nn.Linear(latent_dim, hidden_dim), nn.ReLU(),
nn.Linear(hidden_dim, out_dim)
)
def forward(self, z):
return self.net(z)
class Discriminator(nn.Module):
"""Outputs logits. Use BCEWithLogitsLoss for stability."""
def __init__(self, in_dim=1, hidden_dim=64):
super().__init__()
self.net = nn.Sequential(
nn.Linear(in_dim, hidden_dim), nn.LeakyReLU(0.2),
nn.Linear(hidden_dim, 1) # logits
)
def forward(self, x):
return self.net(x)
# ------------------- Natural Gradient helpers (Gauss–Newton) -----------------
def generator_jacobian_dense(G, z_batch):
"""
Dense per-sample Jacobian of generator outputs wrt parameters.
Returns J: (N, P). This version is robust but O(N*P).
"""
plist = [p for p in G.parameters()]
P_total = sum(p.numel() for p in plist)
N = z_batch.shape[0]
J = torch.zeros(N, P_total, device=z_batch.device, dtype=plist[0].dtype)
# sample-wise vector-Jacobian
for i in range(N):
G.zero_grad(set_to_none=True)
out = G(z_batch[i:i+1]).sum() # scalar
grads = torch.autograd.grad(out, G.parameters(), retain_graph=True, create_graph=False)
J[i, :] = torch.cat([g.reshape(-1) for g in grads])
return J
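# Optional sketch (not used by the script): a vectorized per-sample Jacobian via torch.func,
# assuming PyTorch >= 2.0. Shown only as a possible alternative to the O(N*P) loop above;
# the function name generator_jacobian_vectorized is mine, not part of the original material.
def generator_jacobian_vectorized(G, z_batch):
    from torch.func import functional_call, jacrev, vmap
    params = dict(G.named_parameters())
    def out_fn(p, z):
        # evaluate G with parameters p on a single latent z -> (out_dim,)
        return functional_call(G, p, (z.unsqueeze(0),)).squeeze(0)
    # jacrev w.r.t. the parameter dict, vmapped over the batch dimension of z
    jac = vmap(jacrev(out_fn), in_dims=(None, 0))(params, z_batch)
    N = z_batch.shape[0]
    # each value has shape (N, out_dim, *p.shape); flatten and concatenate -> (N, P) for out_dim == 1
    return torch.cat([j.reshape(N, -1) for j in jac.values()], dim=1)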
def solve_spd(A, b, method="cholesky"):
    """Solve A v = b for SPD A. Prefer Cholesky; fall back to a dense solve."""
    if method == "cholesky":
        try:
            L = torch.linalg.cholesky(A)
            return torch.cholesky_solve(b.unsqueeze(1), L).squeeze(1)
        except RuntimeError:
            pass
    # Fallback: dense solve (PyTorch has no built-in torch.linalg.cg; see the optional CG sketch below)
    return torch.linalg.solve(A, b)
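# Optional sketch: a plain conjugate-gradient solver, in case an iterative fallback is
# preferred over the dense solve above. The name cg_solve and its defaults are mine,
# not part of the original material.
def cg_solve(A, b, tol=1e-5, maxiter=1000):
    x = torch.zeros_like(b)
    r = b - A @ x
    p = r.clone()
    rs_old = torch.dot(r, r)
    for _ in range(maxiter):
        Ap = A @ p
        alpha = rs_old / (torch.dot(p, Ap) + 1e-20)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = torch.dot(r, r)
        if torch.sqrt(rs_new) < tol:
            break
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return x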
# ---------------------- Curvature A_hat and F_hat -----------------------------
def param_directions(model, mode="first_two"):
    """
    Construct two directions ei, ej (lists of tensors with the same shapes as params).
    Default: unit vectors along the first two entries of the first parameter tensor.
    """
    ei = [torch.zeros_like(p) for p in model.parameters()]
    ej = [torch.zeros_like(p) for p in model.parameters()]
    # pick the first parameter tensor; use separate buffers so ei and ej do not alias each other
    for k, p in enumerate(model.parameters()):
        flat_i = torch.zeros(p.numel(), dtype=p.dtype, device=p.device)
        flat_j = torch.zeros(p.numel(), dtype=p.dtype, device=p.device)
        if flat_i.numel() >= 1:
            flat_i[0] = 1.0
            ei[k] = flat_i.view_as(p)
        if flat_j.numel() >= 2:
            flat_j[1] = 1.0
            ej[k] = flat_j.view_as(p)
        break
    return ei, ej
def add_scaled_params_(model, dir_tensors, alpha):
with torch.no_grad():
for p, d in zip(model.parameters(), dir_tensors):
p.add_(alpha * d)
def discriminator_logit(D, x): # x: (N, d)
return D(x)
def phi_logit(D, x):
"""Potential surrogate: logit(D(x)) (already logits)."""
return discriminator_logit(D, x)
def grad_x_phi(D, x):
    r"""
    Compute ∇_x φ(x), where φ(x) = logit(D(x)) serves as the approximate potential.
    x: (N, d) with requires_grad=True
    """
    assert x.requires_grad, "Xgrid must have requires_grad=True."
    with torch.enable_grad():  # re-enable grad even if the caller is under no_grad
        s = D(x).sum()  # scalar: sum of logits
        g, = torch.autograd.grad(s, x, create_graph=False, retain_graph=False, allow_unused=False)
    return g  # (N, d)
def A_hat(G, D, Xgrid, dir_tensors, Delta):
    r"""
    A_hat(theta)(u) ≈ ∇_x [ (φ(θ+Δu) - φ(θ)) / Δ ], evaluated on the grid points Xgrid.
    Note: do not wrap this in no_grad, since autograd.grad is used inside.
    """
    state0 = save_state(G)
    # θ -> θ + Δ u
    add_scaled_params_(G, dir_tensors, +Delta)
    g_pos = grad_x_phi(D, Xgrid)    # ∇_x φ(θ+Δu)
    # restore θ
    load_state(G, state0)
    g_base = grad_x_phi(D, Xgrid)   # ∇_x φ(θ)
    return (g_pos - g_base) / Delta  # (N, d)
def F_hat(G, D, Xgrid, ei, ej, Delta):
r"""
F_hat = A_{θ+Δei}(ej) - A_θ(ej) - A_{θ+Δej}(ei) + A_θ(ei).
no_grad で囲むと autograd.grad が無効化されるため使用しない。
"""
state0 = save_state(G)
# A(theta+Δei)(ej)
add_scaled_params_(G, ei, +Delta)
A_ej_pos = A_hat(G, D, Xgrid, ej, Delta)
load_state(G, state0)
# A(theta)(ej)
A_ej_base = A_hat(G, D, Xgrid, ej, Delta)
# A(theta+Δej)(ei)
add_scaled_params_(G, ej, +Delta)
A_ei_pos = A_hat(G, D, Xgrid, ei, Delta)
load_state(G, state0)
# A(theta)(ei)
A_ei_base = A_hat(G, D, Xgrid, ei, Delta)
return A_ej_pos - A_ej_base - A_ei_pos + A_ei_base # (N, d)
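# One possible reading of the expression above, following the gauge-theory framing of the
# material: dividing F_hat by Δ gives a finite-difference analogue of the abelian field
# strength of the "connection" A on the (e_i, e_j) parameter plane,
#   F_ij ≈ ∂_i A(e_j) - ∂_j A(e_i),
# i.e. the antisymmetrized rate of change of A along the two parameter directions.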
def projected_H_K_on_plane(G, D, z_batch, ei, ej, h=1e-3, device='cpu'):
r"""
2次元平面(ei, ej)上で L_G の 2x2 ヘッセ行列 H を中心差分で近似し、K=0.5(H-H^T) を返す。
ここでは勾配追跡は不要なので item() を使ってOK。
"""
base_state = save_state(G)
def L_at(u, v):
load_state(G, base_state)
add_scaled_params_(G, ei, u); add_scaled_params_(G, ej, v)
logits = D(G(z_batch))
target = torch.ones_like(logits, device=device, dtype=logits.dtype)
return nn.BCEWithLogitsLoss()(logits, target).item()
f00 = L_at(0,0)
f10 = L_at(h,0); f_10 = L_at(-h,0)
f01 = L_at(0,h); f0_1 = L_at(0,-h)
f11 = L_at(h,h)
Huu = (f10 - 2*f00 + f_10)/(h*h)
Hvv = (f01 - 2*f00 + f0_1)/(h*h)
Huv = (f11 - f10 - f01 + f00)/(h*h)
    Hvu = Huv  # symmetric in the ideal case
H = torch.tensor([[Huu, Huv],[Hvu, Hvv]], dtype=torch.float32, device=device)
K = 0.5*(H - H.T)
load_state(G, base_state)
return H, K
# (Optional) projected game-K on 2D (one dir in θ, one in φ) -------------------
def projected_game_K(G, D, z_batch, e_theta, e_phi, h=1e-3, device='cpu'):
"""
Build 2D vector field F(α,β) = [ <∇_θ L_G, e_theta>, <∇_φ L_D, e_phi> ] and get its Jacobian.
The antisymmetric part approximates game rotation on this plane.
"""
base_theta = save_state(G); base_phi = save_state(D)
bce = nn.BCEWithLogitsLoss()
def grad_proj_G(u, v):
load_state(G, base_theta); load_state(D, base_phi)
add_scaled_params_(G, e_theta, u); add_scaled_params_(D, e_phi, v)
z = z_batch
logits = D(G(z))
target = torch.ones_like(logits, device=device, dtype=logits.dtype)
lossG = bce(logits, target)
g_list = torch.autograd.grad(lossG, list(G.parameters()), retain_graph=False, create_graph=False)
g_flat = torch.cat([g.reshape(-1) for g in g_list])
# project onto e_theta
e_flat = torch.cat([d.reshape(-1) for d in e_theta])
return torch.dot(g_flat, e_flat).item()
def grad_proj_D(u, v):
load_state(G, base_theta); load_state(D, base_phi)
add_scaled_params_(G, e_theta, u); add_scaled_params_(D, e_phi, v)
# standard discriminator loss (real->1, fake->0)
real = get_real_data(z_batch.shape[0], device)
fake = G(z_batch).detach()
lossD = bce(D(real), torch.ones_like(D(real))) + bce(D(fake), torch.zeros_like(D(fake)))
d_list = torch.autograd.grad(lossD, list(D.parameters()), retain_graph=False, create_graph=False)
d_flat = torch.cat([g.reshape(-1) for g in d_list])
e_flat = torch.cat([d.reshape(-1) for d in e_phi])
return torch.dot(d_flat, e_flat).item()
# Jacobian by central differences on 2D
f1 = lambda u,v: grad_proj_G(u,v)
f2 = lambda u,v: grad_proj_D(u,v)
# partial derivatives
d11 = (f1(h,0)-f1(-h,0))/(2*h) # ∂/∂u of f1
d12 = (f1(0,h)-f1(0,-h))/(2*h) # ∂/∂v of f1
d21 = (f2(h,0)-f2(-h,0))/(2*h) # ∂/∂u of f2
d22 = (f2(0,h)-f2(0,-h))/(2*h) # ∂/∂v of f2
J = torch.tensor([[d11,d12],[d21,d22]], dtype=torch.float32, device=device)
K = 0.5*(J - J.T)
load_state(G, base_theta); load_state(D, base_phi)
return J, K
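# Intuition for the antisymmetric part: in the bilinear toy game min_x max_y x*y, the
# simultaneous-gradient field is (y, -x), whose Jacobian [[0, 1], [-1, 0]] is purely
# antisymmetric -- pure rotation, the textbook source of cycling in adversarial training.
# The projected K above measures that same kind of rotation on the chosen 2D plane.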
# -------------------------- Training loop (4 variants) ------------------------
def train_g(G_opt="ng",
epochs=3000, bs=256, latent_dim=16, hidden_dim=64,
lr=1e-3, damping=1e-3, momentum_mu=0.9,
jacobian_bs=128, solver="cholesky",
viz_every=1000, # visualize F-hat & projected K
device=None, debug=False):
device = device or get_device()
G = Generator(latent_dim, hidden_dim, 1).to(device)
D = Discriminator(1, hidden_dim).to(device)
optD = optim.Adam(D.parameters(), lr=lr, betas=(0.5,0.999))
if G_opt == "sgd":
optG = optim.SGD(G.parameters(), lr=10*lr)
elif G_opt == "adam":
optG = optim.Adam(G.parameters(), lr=lr, betas=(0.5,0.999))
else:
optG = None
velocity = {n: torch.zeros_like(p) for n,p in G.named_parameters()} if G_opt=="ng_momentum" else None
bce_logits = nn.BCEWithLogitsLoss()
loss_hist = []
for ep in range(1, epochs+1):
# -------------------- D step --------------------
D.train(); G.train()
real = get_real_data(bs, device)
z = torch.randn(bs, latent_dim, device=device)
fake_detached = G(z).detach()
logits_real = D(real)
logits_fake = D(fake_detached)
lossD = bce_logits(logits_real, torch.ones_like(logits_real, dtype=logits_real.dtype, device=device)) + \
bce_logits(logits_fake, torch.zeros_like(logits_fake, dtype=logits_fake.dtype, device=device))
optD.zero_grad(); lossD.backward(); optD.step()
# -------------------- G step --------------------
if G_opt in ("sgd","adam"):
z = torch.randn(bs, latent_dim, device=device)
logits_forG = D(G(z))
target = torch.ones_like(logits_forG, device=device, dtype=logits_forG.dtype)
lossG = bce_logits(logits_forG, target)
G.zero_grad(set_to_none=True); lossG.backward(); optG.step()
else: # NG / NG + momentum (use autograd.grad to get G-only grads)
z = torch.randn(bs, latent_dim, device=device)
g_out = G(z)
logits_forG = D(g_out)
target = torch.ones_like(logits_forG, device=device, dtype=logits_forG.dtype)
lossG = bce_logits(logits_forG, target)
if debug:
print(f"[{G_opt}] sanity: logits requires_grad={logits_forG.requires_grad}, loss requires_grad={lossG.requires_grad}")
            # get G's gradients only, via autograd.grad (backward() is not used)
G.zero_grad(set_to_none=True)
g_list = torch.autograd.grad(lossG, list(G.parameters()),
retain_graph=False, create_graph=False, allow_unused=False)
grad_flat = torch.cat([g.reshape(-1) for g in g_list])
# Build Gauss–Newton metric
zJ = torch.randn(min(bs, jacobian_bs), latent_dim, device=device)
J = generator_jacobian_dense(G, zJ) # (N, P)
N = max(1, J.shape[0])
F = (J.T @ J) / N
P = F.shape[0]
A = F + damping * torch.eye(P, device=F.device, dtype=F.dtype)
v = solve_spd(A, grad_flat, method=solver) # (P,)
with torch.no_grad():
off = 0
if G_opt == "ng":
for p in G.parameters():
n = p.numel()
step = v[off:off+n].view_as(p)
p.add_(-lr * step)
off += n
else: # ng_momentum
for (name, p) in G.named_parameters():
n = p.numel()
step = v[off:off+n].view_as(p)
velocity[name].mul_(momentum_mu).add_(step, alpha=-lr)
p.add_(velocity[name])
off += n
            # discard any gradients on D (D is not updated in this step)
for p in D.parameters():
p.grad = None
loss_hist.append(lossG.item())
if ep % 500 == 0:
print(f"[{G_opt}] ep={ep:4d} L_D={lossD.item():.3f} L_G={lossG.item():.3f}")
# ------------- Visualization (optional) -------------
if viz_every and (ep % viz_every == 0):
xs = torch.linspace(-5, 15, steps=400, device=device).view(-1,1).requires_grad_(True)
ei, ej = param_directions(G)
Delta = 1e-3
F = F_hat(G, D, xs, ei, ej, Delta) # (N,1)
F_norm = F.norm(dim=1).detach().cpu().numpy()
z_small = torch.randn(128, latent_dim, device=device)
H2, K2 = projected_H_K_on_plane(G, D, z_small, ei, ej, h=1e-3, device=device)
print(f"[{G_opt}] projected antisymmetric K (should be ~0 for scalar L_G):\n{K2.cpu().numpy()}")
plt.figure(figsize=(10,4))
plt.subplot(1,2,1); plt.plot(xs.detach().cpu().numpy(), F_norm)
plt.title(r"$\|\widehat{F}(e_i,e_j)(x)\|$"); plt.grid(True)
            kmax = float(K2.abs().max().item()) + 1e-12
            plt.subplot(1,2,2); plt.imshow(K2.cpu().numpy(), cmap="bwr", vmin=-kmax, vmax=kmax)
            plt.colorbar(); plt.title("Projected antisymmetric K")
plt.tight_layout(); plt.show()
return loss_hist, G, D
# ------------------------------- Run all --------------------------------------
if __name__ == "__main__":
set_seed(42)
device = get_device()
# Train 4 optimizers
hist_sgd, G_sgd, D_sgd = train_g("sgd", epochs=3000, lr=1e-3, device=device, viz_every=0)
hist_adam, G_adam, D_adam = train_g("adam", epochs=3000, lr=1e-3, device=device, viz_every=0)
hist_ng, G_ng, D_ng = train_g("ng", epochs=3000, lr=1e-3, damping=1e-3, solver="cholesky", device=device, viz_every=1000)
hist_ngm, G_ngm, D_ngm = train_g("ng_momentum", epochs=3000, lr=1e-3, damping=1e-3, momentum_mu=0.85, device=device, viz_every=1000)
# Plot losses
plt.figure(figsize=(12,5))
plt.plot(hist_sgd, label="SGD")
plt.plot(hist_adam, label="Adam")
plt.plot(hist_ng, label="Natural Gradient")
plt.plot(hist_ngm, label="Natural Gradient + Momentum", linewidth=2.2)
plt.legend(); plt.grid(True); plt.xlabel("epoch"); plt.ylabel("G-loss"); plt.title("Generator loss comparison")
plt.show()
# Distribution comparison
with torch.no_grad():
z = torch.randn(2000, 16, device=device)
def dist(Gm):
return Gm(z).view(-1).cpu().detach().numpy()
real = get_real_data(2000, device).view(-1).cpu().detach().numpy()
models = [G_sgd , G_adam , G_ng , G_ngm]
titles = ["SGD","Adam","NaturalGrad","NaturalGrad+Momentum"]
fig, axes = plt.subplots(1,4, figsize=(22,4), sharey=True)
for ax, M, t in zip(axes, models, titles):
gen = dist(M)
ax.hist(real, bins=60, density=True, alpha=0.6, label="Real")
ax.hist(gen, bins=60, density=True, alpha=0.6, label="Gen")
ax.set_title(t); ax.grid(True)
axes[0].legend(); plt.show()
In the material, the theory rests on the Otto metric on Wasserstein-2 space. Since the actual optimization is over a finite-dimensional neural network, however, the claims are built as an approximation theory: the balance of inertia and damping abstracts the kinematic behavior of the ideal flow, and we show that its validity can be checked through optimization in Euclidean parameter space.
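For reference, here is one compact way (my notation, not quoted from the material) to write the two objects this paragraph contrasts: the Otto-metric (Wasserstein-2) gradient flow of a functional $\mathcal{F}[\rho]$,

$$\partial_t \rho = \nabla\!\cdot\!\Bigl(\rho\,\nabla \frac{\delta \mathcal{F}}{\delta \rho}\Bigr),$$

and its damped second-order ("inertial") counterpart,

$$\partial_t \rho + \nabla\!\cdot\!(\rho v) = 0,\qquad \partial_t v + (v\cdot\nabla)v = -\gamma v - \nabla \frac{\delta \mathcal{F}}{\delta \rho},$$

of which the finite-dimensional heavy-ball update used in the code is the Euclidean abstraction.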
The key to GAN stability, then, may be "appropriate inertia." As the plots show, in GAN training it was far more effective not to simply follow the gradient (the force) faithfully (no inertia / high damping), but to give the dynamics an appropriate amount of inertia and control the resulting momentum (high inertia / low damping); this is what overcame the instability and carried the system toward a stable equilibrium. It matches the theoretical viewpoint of the paper: the learning process should be understood not as plain gradient descent but as the dynamics of a dissipative system in which the balance between inertia and damping is decisive.
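To make the inertia/damping trade-off concrete, here is a tiny self-contained toy of my own (not part of the material): heavy-ball dynamics on a one-dimensional quadratic. With no inertia (mu = 0) the iterate creeps toward the minimum; with moderate inertia (mu ≈ 0.9) it settles fastest; with too much inertia (mu = 0.99) it keeps oscillating, i.e. the system is underdamped.

import numpy as np

# Toy illustration (mine, not from the material): heavy-ball dynamics on f(x) = 0.5*k*x^2.
# The momentum coefficient mu supplies the inertia; (1 - mu) acts as the damping.
def heavy_ball(mu, lr=0.01, k=1.0, x0=5.0, steps=300):
    x, v = x0, 0.0
    traj = []
    for _ in range(steps):
        g = k * x              # gradient of 0.5*k*x^2
        v = mu * v - lr * g    # inertia (mu) vs. damping (1 - mu)
        x = x + v
        traj.append(x)
    return traj

for mu in (0.0, 0.5, 0.9, 0.99):
    tail = float(np.mean(np.abs(heavy_ball(mu)[-20:])))
    print(f"mu={mu:>4}: mean |x| over the last 20 steps = {tail:.2e}")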