Transformer (Self-Attention) and Pokémon


This is demo code, run on Colab, for a tiny language model that uses a Transformer (self-attention) to predict the next character, trained on short Pokémon-style text. The goal is to visualize which characters the attention mechanism is looking at when it produces the next one. It generates and analyzes text (character strings), not images or gameplay.

  1. Inputs and outputs (what goes in, what comes out)
    ・Input: Pokémon-style sentences (e.g. pikachu pika pika!, ポケモン いけ!, etc.)
    ・Training: character-level language modeling, predicting the next single character
    ・Output A: after training, a continuation of the prompt generated character by character (Generate button)
    ・Output B: a heatmap of the attention weights used during generation/prediction (Plot Attention button)

  2. What the code contains (the building blocks)

(1) Installing dependencies
Installs torch / numpy / matplotlib / ipywidgets / numba.

(2) PARAM_INIT
Collects in one place:
・model size (d_model, n_heads, n_layers)
・training settings (batch_size, lr, epochs)
・generation settings (temperature, top_k)

(3) Dataset (corpus)
The strings in pokemon_lines are concatenated into a corpus.
The vocabulary is the set of characters that appear (Latin letters, punctuation, newlines, katakana, etc.).

(4) The model itself: TinyTransformerLM
・Embedding (character ID → vector)
・positional embedding (position information)
・a stack of TransformerBlocks
・a final linear layer that outputs logits over the next character
MultiHeadSelfAttention is implemented by hand here and stores the attention weights (B, H, T, T) in last_attn.

(5) Training
The model is updated with cross_entropy so that it predicts the next character.

(6) Generation
Given a prompt, the text is extended one character at a time by sampling (with temperature / top_k); a sketch of the sampling formula follows this list.

(7) Attention visualization
The attention matrix of the last block is displayed with imshow.
Rows are Query positions and columns are Key positions, so you can see which past characters each position is referring to.

(8) UI
ipywidgets provides buttons that run
・Train
・Generate
・Plot Attention
・Corpus Stats
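
As a note on item (6): temperature sampling divides the logits by a temperature before the softmax, and top-k keeps only the k largest logits (the rest are set to −∞), which is what sample_next in the code does. Written out (the symbols z_i and τ are just notation for the logits and the temperature, not names used in the code):

$$
p_i = \frac{\exp(z_i/\tau)}{\sum_j \exp(z_j/\tau)}
$$

A low τ sharpens the distribution toward the most likely character; a high τ flattens it toward uniform.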

  3. What makes this a "Transformer / attention" model
    ・The core of a Transformer is self-attention: each position learns how much weight to put on each of the preceding characters.
    ・This code pulls out those attention weights and displays them.
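
The self-attention computed in the code is the standard scaled dot-product form, where Q, K, V are the query/key/value projections and d_head is the per-head dimension (matching MultiHeadSelfAttention below):

$$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_\mathrm{head}}} + M\right)V
$$

Here M is the causal mask (−∞ above the diagonal), so each position can only attend to itself and earlier characters. The softmax output is the (B, H, T, T) tensor stored in last_attn and shown by Plot Attention.
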
# Program Name: pokemon_transformer_attention_colab.py
# Creation Date: 20260122
# Purpose: Train a tiny Transformer (self-attention) on Pokémon-themed text and visualize attention.

# ============================================================
# 0) Install / Imports
# ============================================================
# NOTE: Colab usually has torch, but we pin versions for reproducibility.
# 注意: Colabにはtorchが入っていることが多いが、再現性のため版指定を行う。

try:
    !pip -q install "numpy==1.26.4" "matplotlib==3.8.4" "ipywidgets==8.1.2" "numba==0.59.1" "torch==2.2.2" "torchvision==0.17.2" "torchaudio==2.2.2"
except Exception as e:
    print("[WARN] pip install failed (may already be installed).", e)

import os
import math
import time
import random
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from numba import njit
import ipywidgets as widgets
from IPython.display import display, clear_output

# ============================================================
# 1) PARAM_INIT (centralized settings)
# ============================================================
PARAM_INIT = {
    # Seed / 乱数seed
    "seed": 123,

    # Dataset / データ
    "block_size": 64,        # context length / 文脈長
    "min_text_repeat": 3,    # repeat corpus to increase samples / コーパス反復

    # Model / モデル
    "d_model": 128,
    "n_heads": 4,
    "n_layers": 2,
    "d_ff": 256,
    "dropout": 0.10,

    # Training / 学習
    "batch_size": 64,
    "lr": 2e-3,
    "weight_decay": 1e-4,
    "epochs": 3,
    "max_steps_per_epoch": 200,  # keep it small / 軽量化
    "device": "cuda" if torch.cuda.is_available() else "cpu",

    # Generation / 生成
    "gen_max_new_chars": 200,
    "gen_temperature": 1.0,
    "gen_top_k": 30,

    # Output saving / 保存
    "out_dir": "/mnt/data",
}
os.makedirs(PARAM_INIT["out_dir"], exist_ok=True)

# ============================================================
# 2) Reproducibility / 再現性
# ============================================================
def set_seed(seed: int):
    """Inputs/Outputs/Process:
    Inputs: seed (int)
    Outputs: None
    Process: Fix random seeds for python/numpy/torch.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed(PARAM_INIT["seed"])

# ============================================================
# 3) Pokémon-themed corpus (toy) / ポケモン風コーパス(簡易)
# ============================================================
# English + Katakana + simple phrases, char-level modeling.
# 英語+カタカナ+簡単フレーズ。文字単位モデル。
pokemon_lines = [
    "pikachu pika pika!\n",
    "bulbasaur bulba saur!\n",
    "charmander char!\n",
    "squirtle squirt squirt!\n",
    "eevee v~ee!\n",
    "mew mew...\n",
    "mewtwo ...\n",
    "gengar genga!\n",
    "snorlax zzz...\n",
    "lucario aura!\n",
    "ポケモン いけ!\n",
    "ピカチュウ でんきショック!\n",
    "フシギダネ はっぱカッター!\n",
    "ヒトカゲ ひのこ!\n",
    "ゼニガメ みずでっぽう!\n",
    "進化 する!\n",
    "きみに きめた!\n",
    "trainer battle start!\n",
    "critical hit!\n",
    "use thunderbolt!\n",
]

corpus = "".join(pokemon_lines) * PARAM_INIT["min_text_repeat"]

# Build vocab / 語彙(文字集合)
chars = sorted(list(set(corpus)))
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for ch, i in stoi.items()}
vocab_size = len(chars)

def encode(s: str) -> np.ndarray:
    return np.array([stoi[c] for c in s], dtype=np.int64)

def decode(ix: np.ndarray) -> str:
    return "".join([itos[int(i)] for i in ix])

data = encode(corpus)
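
# --- Optional sanity check (added for illustration): encode/decode should round-trip
# --- for any string whose characters appear in the corpus.
assert decode(encode("pikachu pika pika!\n")) == "pikachu pika pika!\n"
print(f"[INFO] vocab_size={vocab_size}, corpus length={len(data)} chars")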

# ============================================================
# 4) Dataset / DataLoader
# ============================================================
class CharBlockDataset(Dataset):
    """Inputs/Outputs/Process:
    Inputs: token array, block_size
    Outputs: (x, y) where y is next-token target
    Process: Return random contiguous blocks for next-token prediction.
    """
    def __init__(self, tokens: np.ndarray, block_size: int):
        self.tokens = tokens
        self.block_size = block_size

    def __len__(self):
        return max(1, len(self.tokens) - self.block_size - 1)

    def __getitem__(self, idx):
        x = self.tokens[idx: idx + self.block_size]
        y = self.tokens[idx + 1: idx + 1 + self.block_size]
        return torch.tensor(x, dtype=torch.long), torch.tensor(y, dtype=torch.long)

dataset = CharBlockDataset(data, PARAM_INIT["block_size"])
loader = DataLoader(dataset, batch_size=PARAM_INIT["batch_size"], shuffle=True, drop_last=True)

# ============================================================
# 5) Transformer with attention weights (custom) / 注意重みを取り出すTransformer
# ============================================================
class MultiHeadSelfAttention(nn.Module):
    """Self-attention with weight export.
    注意: attn_weights を取り出せるように実装。
    """
    def __init__(self, d_model: int, n_heads: int, dropout: float):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads

        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.proj = nn.Linear(d_model, d_model, bias=False)
        self.drop = nn.Dropout(dropout)

        self.last_attn = None  # (B, H, T, T)

    def forward(self, x, attn_mask=None):
        B, T, C = x.shape
        qkv = self.qkv(x)  # (B, T, 3C)
        q, k, v = qkv.chunk(3, dim=-1)

        # reshape to heads / ヘッド分割
        q = q.view(B, T, self.n_heads, self.d_head).transpose(1, 2)  # (B, H, T, Dh)
        k = k.view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.d_head).transpose(1, 2)

        # scaled dot-product attention / スケールド内積注意
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_head)  # (B, H, T, T)

        if attn_mask is not None:
            scores = scores.masked_fill(attn_mask == 0, float("-inf"))

        attn = F.softmax(scores, dim=-1)
        attn = self.drop(attn)

        self.last_attn = attn.detach()  # store for visualization
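        # NOTE: only the most recent forward pass is kept here; during generation it is
        # overwritten at every sampling step, so plot_attention() below runs a fresh
        # forward pass on the prompt before reading these weights.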

        out = attn @ v  # (B, H, T, Dh)
        out = out.transpose(1, 2).contiguous().view(B, T, C)  # (B, T, C)
        out = self.proj(out)
        out = self.drop(out)
        return out

class TransformerBlock(nn.Module):
    """Transformer block: LN -> MHA -> residual -> LN -> FFN -> residual."""
    def __init__(self, d_model, n_heads, d_ff, dropout):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = MultiHeadSelfAttention(d_model, n_heads, dropout)
        self.ln2 = nn.LayerNorm(d_model)

        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x, attn_mask=None):
        x = x + self.attn(self.ln1(x), attn_mask=attn_mask)
        x = x + self.ff(self.ln2(x))
        return x

class TinyTransformerLM(nn.Module):
    """Character-level Transformer language model."""
    def __init__(self, vocab_size, block_size, d_model, n_heads, n_layers, d_ff, dropout):
        super().__init__()
        self.vocab_size = vocab_size
        self.block_size = block_size
        self.d_model = d_model

        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(block_size, d_model)
        self.drop = nn.Dropout(dropout)

        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)
        ])
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)

        # causal mask / 因果マスク(未来を見ない)
        mask = torch.tril(torch.ones(block_size, block_size)).unsqueeze(0).unsqueeze(0)  # (1,1,T,T)
        self.register_buffer("causal_mask", mask)

    def forward(self, idx):
        B, T = idx.shape
        assert T <= self.block_size

        pos = torch.arange(0, T, device=idx.device)
        x = self.tok_emb(idx) + self.pos_emb(pos)[None, :, :]
        x = self.drop(x)

        attn_mask = self.causal_mask[:, :, :T, :T]
        for blk in self.blocks:
            x = blk(x, attn_mask=attn_mask)

        x = self.ln_f(x)
        logits = self.head(x)  # (B,T,V)
        return logits

    def get_last_attention(self):
        # Return last block's attention weights if available
        if len(self.blocks) == 0:
            return None
        return self.blocks[-1].attn.last_attn  # (B,H,T,T)

model = TinyTransformerLM(
    vocab_size=vocab_size,
    block_size=PARAM_INIT["block_size"],
    d_model=PARAM_INIT["d_model"],
    n_heads=PARAM_INIT["n_heads"],
    n_layers=PARAM_INIT["n_layers"],
    d_ff=PARAM_INIT["d_ff"],
    dropout=PARAM_INIT["dropout"],
).to(PARAM_INIT["device"])
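
# Optional shape check (illustrative, not part of the original flow): a forward pass
# on random token IDs should yield logits of shape (B, T, vocab_size) and last-block
# attention of shape (B, n_heads, T, T).
with torch.no_grad():
    _dummy = torch.randint(0, vocab_size, (2, 8), device=PARAM_INIT["device"])
    print("[INFO] logits shape:", tuple(model(_dummy).shape))
    print("[INFO] attention shape:", tuple(model.get_last_attention().shape))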

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=PARAM_INIT["lr"],
    weight_decay=PARAM_INIT["weight_decay"],
)

# ============================================================
# 6) Training / 学習
# ============================================================
def train_one_epoch(epoch_i: int):
    """Inputs/Outputs/Process:
    Inputs: epoch index
    Outputs: average loss
    Process: Train a small number of steps for demo.
    """
    model.train()
    losses = []
    step = 0
    for xb, yb in loader:
        xb = xb.to(PARAM_INIT["device"])
        yb = yb.to(PARAM_INIT["device"])

        logits = model(xb)
        loss = F.cross_entropy(logits.view(-1, vocab_size), yb.view(-1))

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        losses.append(float(loss.detach().cpu().item()))
        step += 1
        if step >= PARAM_INIT["max_steps_per_epoch"]:
            break
    return float(np.mean(losses)) if losses else float("nan")

# ============================================================
# 7) Generation / 生成
# ============================================================
@torch.no_grad()
def sample_next(logits, temperature=1.0, top_k=None):
    """Sample from logits with temperature and optional top-k.
    ロジットから温度付きサンプリング(top-k対応)。
    """
    logits = logits / max(1e-8, temperature)
    if top_k is not None and top_k > 0:
        v, ix = torch.topk(logits, min(top_k, logits.shape[-1]))
        mask = torch.full_like(logits, float("-inf"))
        mask.scatter_(dim=-1, index=ix, src=v)
        logits = mask
    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)

@torch.no_grad()
def generate(prompt: str, max_new_chars: int, temperature: float, top_k: int):
    model.eval()
    idx = torch.tensor(encode(prompt), dtype=torch.long, device=PARAM_INIT["device"])[None, :]
    for _ in range(max_new_chars):
        idx_cond = idx[:, -PARAM_INIT["block_size"]:]
        logits = model(idx_cond)
        last = logits[:, -1, :]
        nxt = sample_next(last, temperature=temperature, top_k=top_k)
        idx = torch.cat([idx, nxt], dim=1)
    out = idx[0].detach().cpu().numpy()
    return decode(out)
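
# Example usage (illustrative; works even before training, but an untrained model
# produces near-random characters):
#   print(generate("pikachu", max_new_chars=80, temperature=1.0, top_k=30))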

# ============================================================
# 8) Attention visualization / 注意行列の可視化
# ============================================================
def plot_attention(prompt: str, max_len: int = 64, head: int = 0, save: bool = False):
    """Plot attention matrix for a prompt (last block).
    promptの注意重み(最終ブロック)を表示。
    """
    model.eval()
    x = torch.tensor(encode(prompt), dtype=torch.long, device=PARAM_INIT["device"])[None, :]
    x = x[:, :min(x.shape[1], max_len)]
    T = x.shape[1]
    _ = model(x)
    attn = model.get_last_attention()
    if attn is None:
        print("No attention weights available.")
        return

    # attn: (B,H,T,T)
    A = attn[0, head, :T, :T].detach().cpu().numpy()

    plt.figure(figsize=(7, 6))
    plt.imshow(A, aspect="auto")
    plt.colorbar()
    plt.title("Attention Weights (Last Block)")
    plt.xlabel("Key position")
    plt.ylabel("Query position")

    # token labels (chars)
    toks = [itos[int(i)] for i in x[0].detach().cpu().numpy()]
    step = max(1, T // 16)  # avoid clutter
    ticks = list(range(0, T, step))
    plt.xticks(ticks, [toks[i] for i in ticks], rotation=90)
    plt.yticks(ticks, [toks[i] for i in ticks])

    plt.tight_layout()

    if save:
        ts = time.strftime("%Y%m%d_%H%M%S")
        fname_png = os.path.join(PARAM_INIT["out_dir"], f"attn_{ts}_head{head}.png")
        fname_pdf = os.path.join(PARAM_INIT["out_dir"], f"attn_{ts}_head{head}.pdf")
        plt.savefig(fname_png, dpi=200)
        plt.savefig(fname_pdf)
        print("Saved:", fname_png)
        print("Saved:", fname_pdf)

    plt.show()
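
# Example usage (illustrative): plot head 0 for one of the corpus prompts and save
# PNG/PDF copies under PARAM_INIT["out_dir"]:
#   plot_attention("pikachu pika pika!\n", max_len=PARAM_INIT["block_size"], head=0, save=True)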

# ============================================================
# 9) (Optional) Numba helper for fast histogram / おまけ:ヒストグラム高速化
# ============================================================
@njit
def token_histogram(tokens, vocab_size):
    h = np.zeros(vocab_size, dtype=np.int64)
    for i in range(tokens.size):
        t = tokens[i]
        if 0 <= t < vocab_size:
            h[t] += 1
    return h
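
# Note: the first token_histogram() call triggers numba JIT compilation (a short
# one-time delay); subsequent calls run at compiled speed.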

def show_corpus_stats():
    h = token_histogram(data, vocab_size)
    top = np.argsort(-h)[:20]
    print("Top chars by frequency:")
    for i in top:
        ch = itos[int(i)]
        print(f"  '{ch}' : {int(h[i])}")

# ============================================================
# 10) UI (ipywidgets) / UI
# ============================================================
prompt_dropdown = widgets.Dropdown(
    options=[
        "pikachu pika pika!\n",
        "ポケモン いけ!\n",
        "ピカチュウ でんきショック!\n",
        "trainer battle start!\n",
        "use thunderbolt!\n",
        "進化 する!\n",
    ],
    value="pikachu pika pika!\n",
    description="Prompt:",
    layout=widgets.Layout(width="70%")
)

epochs_slider = widgets.IntSlider(
    value=PARAM_INIT["epochs"], min=1, max=20, step=1,
    description="Epochs:", continuous_update=False
)

steps_slider = widgets.IntSlider(
    value=PARAM_INIT["max_steps_per_epoch"], min=50, max=2000, step=50,
    description="Steps/ep:", continuous_update=False
)

temp_slider = widgets.FloatSlider(
    value=PARAM_INIT["gen_temperature"], min=0.3, max=2.0, step=0.1,
    description="Temp:", continuous_update=False
)

topk_slider = widgets.IntSlider(
    value=PARAM_INIT["gen_top_k"], min=0, max=100, step=5,
    description="Top-k:", continuous_update=False
)

genlen_slider = widgets.IntSlider(
    value=120, min=20, max=500, step=10,
    description="New chars:", continuous_update=False
)

head_slider = widgets.IntSlider(
    value=0, min=0, max=PARAM_INIT["n_heads"] - 1, step=1,
    description="Head:", continuous_update=False
)

save_checkbox = widgets.Checkbox(value=False, description="Save PNG/PDF")

btn_train = widgets.Button(description="Train", button_style="primary")
btn_generate = widgets.Button(description="Generate", button_style="success")
btn_attn = widgets.Button(description="Plot Attention", button_style="info")
btn_stats = widgets.Button(description="Corpus Stats", button_style="")

out = widgets.Output()

def on_train_clicked(_):
    with out:
        clear_output()
        try:
            PARAM_INIT["epochs"] = int(epochs_slider.value)
            PARAM_INIT["max_steps_per_epoch"] = int(steps_slider.value)
            print(f"Device: {PARAM_INIT['device']}")
            print(f"Vocab size: {vocab_size}, Block size: {PARAM_INIT['block_size']}")
            for ep in range(PARAM_INIT["epochs"]):
                loss = train_one_epoch(ep)
                print(f"Epoch {ep+1}/{PARAM_INIT['epochs']}  Loss: {loss:.4f}")
        except Exception as e:
            print("[ERROR] Training failed:", e)

def on_generate_clicked(_):
    with out:
        clear_output()
        try:
            prompt = prompt_dropdown.value
            text = generate(
                prompt=prompt,
                max_new_chars=int(genlen_slider.value),
                temperature=float(temp_slider.value),
                top_k=int(topk_slider.value) if int(topk_slider.value) > 0 else None
            )
            print("----- Generated Text -----")
            print(text)
        except Exception as e:
            print("[ERROR] Generation failed:", e)

def on_attn_clicked(_):
    with out:
        clear_output()
        try:
            prompt = prompt_dropdown.value
            plot_attention(
                prompt=prompt,
                max_len=PARAM_INIT["block_size"],
                head=int(head_slider.value),
                save=bool(save_checkbox.value)
            )
        except Exception as e:
            print("[ERROR] Attention plot failed:", e)

def on_stats_clicked(_):
    with out:
        clear_output()
        try:
            show_corpus_stats()
        except Exception as e:
            print("[ERROR] Stats failed:", e)

btn_train.on_click(on_train_clicked)
btn_generate.on_click(on_generate_clicked)
btn_attn.on_click(on_attn_clicked)
btn_stats.on_click(on_stats_clicked)

ui = widgets.VBox([
    widgets.HBox([prompt_dropdown]),
    widgets.HBox([epochs_slider, steps_slider]),
    widgets.HBox([temp_slider, topk_slider, genlen_slider]),
    widgets.HBox([head_slider, save_checkbox]),
    widgets.HBox([btn_train, btn_generate, btn_attn, btn_stats]),
    out
])

display(ui)

# ============================================================
# 11) Quick sanity run (optional) / 動作確認(任意)
# ============================================================
# You can click "Train" then "Generate" and "Plot Attention".
# まず Train → Generate → Plot Attention の順で確認。