0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

【コード解説編】論理ゲートで Transformer を超える実装 (PPL 4.73)

0
Posted at

【コード解説編】論理ゲートで Transformer を超える実装 (PPL 4.73)

論理ゲートだけで言語モデルを作って Transformer (PPL 4.86) を 0.13 上回った実装の解説です。
DLGN, HBA, 知識蒸留の 実コード を中心に、再現に必要な要点をまとめます。

物語 / 失敗譚は 物語編 を参照してください。

環境

  • Python 3.10+, PyTorch 2.1+
  • RTX 4060 8GB(CPU でも動作可、学習時間は伸びます)
git clone https://github.com/karumaru-kakikukekodoumei/boolean-attention.git
cd boolean-attention
pip install -r requirements.txt

Step 1. 微分可能な論理ゲート層 (DLGN)

論理ゲートは離散関数で勾配が流れません。2 入力ブール関数は $2^4 = 16$ 種類しかないという事実を使い、16 ゲートを softmax で混合 することで勾配を流します。

import torch
import torch.nn as nn
import torch.nn.functional as F

def all_gates(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """16 種類の論理ゲートを (..., 16) 次元で返す。a, b は [0,1] 連続値想定。"""
    return torch.stack([
        torch.zeros_like(a), a * b, a * (1 - b), a,
        (1 - a) * b, b, a + b - 2*a*b, a + b - a*b,
        1 - (a + b - a*b), 1 - (a + b - 2*a*b), 1 - b, a + (1-b) - a*(1-b),
        1 - a, (1-a) + b - (1-a)*b, 1 - a*b, torch.ones_like(a),
    ], dim=-1)


class DLGNLayer(nn.Module):
    def __init__(self, in_dim: int, out_dim: int, tau: float = 1.0):
        super().__init__()
        self.in_dim, self.out_dim = in_dim, out_dim
        self.tau = tau
        self.pair_a = nn.Parameter(torch.randn(out_dim, in_dim) * 0.5)
        self.pair_b = nn.Parameter(torch.randn(out_dim, in_dim) * 0.5)
        self.gate_logits = nn.Parameter(torch.randn(out_dim, 16) * 0.1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        wa = F.softmax(self.pair_a / self.tau, dim=-1)
        wb = F.softmax(self.pair_b / self.tau, dim=-1)
        a = x @ wa.T
        b = x @ wb.T
        gates = all_gates(a, b)
        alpha = F.softmax(self.gate_logits / self.tau, dim=-1)
        return (gates * alpha).sum(dim=-1)

学習が終わったら argmax(self.gate_logits) で 1 個に確定すれば、純粋なブーリアン回路に戻ります(hard collapse)。

python src/dlgn_charlm.py
結果 Soft PPL Hard PPL
DLGN flat (4 層) 11.83 15.16
Transformer (比較) 4.86

論理回路で言語学習はできた、ただし TF には届かず。次の設計へ。

Step 2. (失敗例) LoopedDLGN

DLGN を T 回繰り返す Universal Transformer 風設計。撃沈例として参考までに残します。

python src/looped_dlgn_charlm.py --max-iters=8
Soft PPL Hard PPL
v1 (PE なし) 11.05 754.31

ハードコラプス時に PPL が 754 まで暴騰。反復ごとの量子化誤差が

ε_total ≈ Σ_t ‖f_hard(x⁽ᵗ⁾) - f_soft(x⁽ᵗ⁾)‖

として深さ方向に蓄積するためです。反復系は Boolean と相性が悪い という構造的な学び。

Step 3. HBA — Boolean Router + float Value

Attention のルーターだけを Boolean 化、値集約は float のまま。

import torch.nn.utils as nn_utils

class BooleanAttentionLayer(nn.Module):
    def __init__(self, d: int, tau: float = 0.1):
        super().__init__()
        self.q = nn.Linear(d, d)
        self.k = nn.Linear(d, d)
        self.v = nn.Linear(d, d)
        # bilinear router (Lipschitz 制約に spectral norm)
        self.w_router = nn_utils.spectral_norm(nn.Linear(d, d, bias=False))
        self.tau = tau

    def forward(self, x: torch.Tensor, causal_mask: torch.Tensor) -> torch.Tensor:
        Q, K, V = self.q(x), self.k(x), self.v(x)  # [B, T, d]
        # Q · W · K^T
        logits = Q @ self.w_router.weight @ K.transpose(-1, -2)  # [B, T, T]
        logits = logits.masked_fill(causal_mask, float("-inf"))

        if self.training:
            router = torch.tanh(logits / self.tau)  # 連続近似
        else:
            router = torch.sign(logits)  # 推論時離散

        attn = F.softmax(router / self.tau, dim=-1)
        return attn @ V  # V は float のまま

ポイント:

  • ルーターは離散値 (-1, +1) に確定
  • 値の集約は float なので 量子化誤差が深さ方向に伝播しない
  • spectral norm で router 重みのリプシッツ性を担保(発散防止)
python src/hba_charlm.py --epochs=60
HBA v1 Best PPL Final PPL
Ep12 / Ep60 5.40 9.75

TF (4.86) まで 0.54 差 まで肉薄。ただし過学習が課題。

Step 4. HBA v2 — 安定化 4 点セット

# 1. Best checkpoint
if val_ppl < best_ppl:
    best_ppl = val_ppl
    best_state = {k: v.clone() for k, v in model.state_dict().items()}
    best_epoch = ep
    bad_count = 0
else:
    bad_count += 1

# 2. Early stopping
if bad_count >= patience:
    print(f"early stop at ep {ep}")
    break

# 3. Hard threshold calibration
def calibrate_hard_threshold(model, val_loader, taus=(0.05, 0.08, 0.1, 0.15, 0.2)):
    best = (None, float("inf"))
    for tau in taus:
        model.set_inference_tau(tau)
        ppl = evaluate(model, val_loader)
        if ppl < best[1]:
            best = (tau, ppl)
    return best

# 4. warm_hold 温度スケジュール
def temperature_schedule(epoch: int) -> float:
    if epoch < 5:  return 1.0           # warm: 柔らかく
    if epoch < 15: return 0.5           # hold: 中間
    return max(0.1, 0.5 * 0.95**(epoch - 15))  # decay
python src/hba_charlm.py --epochs=40 --early-stop --calibrate
HBA v2 Soft PPL Hard PPL Train time
結果 5.32 6.54 4.7 min

LoopedDLGN の Hard PPL 754 と比べて 115 倍の改善

Step 5. 知識蒸留で TF 越え

教師 (TF) → 生徒 (HBA v2 構造) に蒸留。ハイブリッド損失で CE と KL を併用。

def distill_loss(
    student_logits: torch.Tensor,
    teacher_logits: torch.Tensor,
    targets: torch.Tensor,
    alpha: float = 0.3,
    T: float = 8.0,
) -> torch.Tensor:
    ce = F.cross_entropy(student_logits, targets)
    kl = F.kl_div(
        F.log_softmax(student_logits / T, dim=-1),
        F.softmax(teacher_logits / T, dim=-1),
        reduction="batchmean",
    )
    return alpha * ce + (1 - alpha) * (T * T) * kl
python src/hba_distill_charlm.py --epochs=30 --teacher-ckpt=teacher_tf.pt
Soft PPL
Teacher (TF) 4.86
Student (HBA distilled) 4.73
逆転幅 -0.13

論理回路ベースのモデルが Transformer を逆転。born-again networks (Furlanello et al. 2018) として知られる現象です。

ハマりどころ: 温度整合性のバグ

初期実装で訓練 eval と最終比較で温度 $\tau$ が違っていて、PPL が 4.71 vs 8.72 と乖離するバグに数日とられました。

Bad

# 訓練 eval は固定 tau=1.0、最終比較は final_tau=0.1 と別物
def evaluate(model, loader):
    model.set_inference_tau(1.0)
    ...

# 最終比較
model.set_inference_tau(0.1)  # ← 急に厳しい τ にする
final_ppl = evaluate(model, test_loader)

Good

# 訓練 eval は「現在のスケジューラ τ」で評価
def evaluate(model, loader, tau: float):
    model.set_inference_tau(tau)
    ...

# 最終比較は best epoch 時点の実 τ を逆引き
best_tau = temperature_schedule(best_epoch)
model.load_state_dict(best_state)
model.set_inference_tau(best_tau)
final_ppl = evaluate(model, test_loader, best_tau)

これで再現性のある PPL 4.73 が出るようになりました。

再現手順まとめ

# 1. ベースライン
python src/dlgn_charlm.py        # PPL 11.83

# 2. 失敗パス (任意)
python src/looped_dlgn_charlm.py # PPL 754 で爆死を体感

# 3. HBA v2
python src/hba_charlm.py --early-stop --calibrate  # PPL 5.32

# 4. 蒸留 (要 teacher checkpoint)
python src/train_teacher.py
python src/hba_distill_charlm.py  # PPL 4.73 → TF 越え

学習ログは results/、学習済み ChatHBA は checkpoints/ にあります。

応用先

HBA は 特化用途で実用性あり という結論:

  • Speculative decoding のドラフトモデル — 大きい教師モデルとの並用で軽量ルーティング
  • エッジ推論 — CPU/MCU で動く軽量 LM
  • 電力制約環境 — GPU を持たないシステム

まとめ

Step やったこと 結果
1 DLGN 層 16 ゲート softmax 混合で勾配 OK
2 DLGN flat PPL 11.83 (TF 4.86 未達)
3 LoopedDLGN PPL 754 で構造的に詰む
4 HBA v1 PPL 5.40 (TF まで 0.54 差)
5 HBA v2 PPL 5.32 / Hard 6.54
6 知識蒸留 PPL 4.73 (TF 4.86 → 0.13 上回る)

リンク

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?