Coupled Adam とは、Adam の 2 次モーメントが埋め込みベクトルを偏在させる仮説に基づき、成分ごとに全トークンで共通の 2 次モーメントを用いるように変更したものです。2025年の ACL 2025(Long Papers)に採択された論文 Better Embeddings with Coupled Adam で提案されました。
以下に文献概要とトイコードを記します。
参考文献
Felix Stollenwerk and Tobias Stollenwerk, Better Embeddings with Coupled Adam, Proceedings of the 63rd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), 2025.
文献概要
- Adam (私の検算記事) は、1, 2次モーメントを用いる勾配降下法です。Adam のネットワークの重み $w_t$ の更新式は以下です (ただし訓練初期の補正と安定化の $\varepsilon$ を省略)。
- $w_t = w_{t-1} - \alpha \cdot ( m_t / \sqrt{v_t})$ ($\alpha$ は学習率)
- $m_t$ : 1 次モーメント = 勾配の移動平均 = 平均的な勾配方向
- $v_t$ : 2 次モーメント = 勾配の 2 乗の移動平均 = 平均的な勾配の大きさ (Adam の更新式は、勾配の大きさを正規化する形になっている)
- $w_t = w_{t-1} - \alpha \cdot ( m_t / \sqrt{v_t})$ ($\alpha$ は学習率)
- しかし、Adam による最適化ではトークンの埋め込みベクトルが偏在し、平均ベクトルが原点からずれることが観測されていました。
- Stollenwerk らは、偏在の原因が、2 次モーメントによる正規化にあると考えました。
- 正解トークン予測タスクを考えると、基本的な SGD であれば、全トークン (1 つの正解トークンでその他が不正解トークン) の埋め込みベクトルの平均はずれないはずだからです (Figure 1)。
- Stollenwerk らは、Adam の 2 次モーメントの期待値はそのトークンの出現割合に比例することを実験的に確かめました。つまり、Adam は、低頻度トークンの小さいはずの更新幅を大きくし、高頻度トークンの大きいはずの更新幅を小さくします。
- そこで Stollenwerk らは、成分ごとに全トークンで共通の 2 次モーメントを用いる Coupled Adam を提案しました。トークン $i$ の埋め込みベクトルの第 $k$ 成分の更新式は以下です (ただし訓練初期の補正と安定化の $\varepsilon$ を省略)。
- Adam : $w_t^{(i,k)} = w_{t-1}^{(i,k)} - \alpha \cdot ( m_t^{(i,k)} / \sqrt{v_t^{(i,k)}})$
- Coupled Adam : $w_t^{(i,k)} = w_{t-1}^{(i,k)} - \alpha \cdot ( m_t^{(i,k)} / \sqrt{ (1/V)\sum_{j=1}^{V} v_t^{(j,k)}})$
- 検証実験では、自然言語モデルを事前学習し、下流タスク (理科問題回答タスクや次文選択タスクなど) を few-shot で回答させたときの正解率を計測しました。なお、Coupled Adam を用いたのは埋め込み層だけで、他の層は Adam のままとしました。
- 結果として、埋め込みベクトルの偏りは一貫して小さくなりました。
- ただし、正解率の面では、小規模・長めの学習では改善しやすい一方、モデルサイズに対して訓練トークン数が少ない条件では悪化する場合もありました。
- fine-tuning は検証実験に含めていません。
トイコード
以下は、次トークンのスコア分布を出力する出力層に対する勾配と更新量を次元ごとに足し合わせることで、SGD / Adam / Coupled Adam で各トークンの重みベクトルの中心の動き方がどう異なるかを確かめるトイコードです。
ここでは簡単のため、入力トークンの埋め込み層ではなく、各トークンのスコアを計算する出力層の重みベクトルをみています。
import torch
class Model(torch.nn.Module):
# 3 次元の隠れ状態 h から次トークンのスコア分布を出力する
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(3, 5, bias=False) # 語彙数 5
def forward(self, h):
return self.linear(h)
class CoupledAdam(torch.optim.Optimizer):
def __init__(self, params, lr=0.1, betas=(0.9, 0.999), eps=1e-8):
defaults = {'lr': lr, 'betas': betas, 'eps': eps}
super().__init__(params, defaults)
@torch.no_grad()
def step(self):
for group in self.param_groups:
lr = group['lr']
beta1, beta2 = group['betas']
eps = group['eps']
for p in group['params']:
if p.grad is None:
continue
grad = p.grad
state = self.state[p]
if len(state) == 0:
state['step'] = 0
state['exp_avg'] = torch.zeros_like(p)
state['exp_avg_sq'] = torch.zeros_like(p)
state['step'] += 1
step = state['step']
exp_avg = state['exp_avg']
exp_avg_sq = state['exp_avg_sq']
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
m_hat = exp_avg / (1 - beta1 ** step)
v_hat = exp_avg_sq / (1 - beta2 ** step)
if p.ndim == 2:
v_hat = v_hat.mean(dim=0, keepdim=True).expand_as(p)
p.addcdiv_(m_hat, v_hat.sqrt().add(eps), value=-lr)
def run(name, make_optimizer):
torch.manual_seed(0)
model = Model()
h = torch.randn(4, 3) # 4 サンプルの 3 次元の隠れ状態
y = torch.tensor([0, 1, 2, 3]) # 正解トークンID
before = model.linear.weight.detach().clone()
optimizer = make_optimizer(model.parameters())
loss = torch.nn.functional.cross_entropy(model(h), y)
loss.backward()
grad_sum = model.linear.weight.grad.sum(dim=0)
optimizer.step()
update = model.linear.weight.detach() - before
update_sum = update.sum(dim=0)
print(f'===== {name} =====')
fmt = lambda x: '[' + ', '.join(f'{v:+.4e}' for v in x.tolist()) + ']'
print('grad sum: ', fmt(grad_sum))
print('update sum:', fmt(update_sum))
print()
def main():
run('SGD', lambda params: torch.optim.SGD(params, lr=0.1))
run('Adam', lambda params: torch.optim.Adam(params, lr=0.1))
run('Coupled Adam', lambda params: CoupledAdam(params, lr=0.1))
if __name__ == '__main__':
main()
実行すると以下のようになります。
===== SGD =====
grad sum: [+2.9802e-08, -1.4901e-08, -2.9802e-08]
update sum: [+3.3528e-08, +0.0000e+00, -1.4901e-08]
===== Adam =====
grad sum: [+2.9802e-08, -1.4901e-08, -2.9802e-08]
update sum: [-1.0000e-01, -1.0000e-01, -1.0000e-01]
===== Coupled Adam =====
grad sum: [+2.9802e-08, -1.4901e-08, -2.9802e-08]
update sum: [-1.4901e-08, +2.9802e-08, +2.9802e-08]