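This is demo code that runs an “ultra-small language model” on Colab. A Transformer (Self-Attention) is trained on short Pokémon-style text to predict the next character. The goal is to visualize where the attention mechanism looks when it produces the next character. It generates and analyzes character strings (text), not images or games.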
-
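Input and output (what goes in, what comes out)
・Input: Pokémon-style sentences (e.g. pikachu pika pika!, ポケモン いけ!)
・Training: character-level "predict the next character" training (language modeling)
・Output A: after training, automatically generates a continuation of a prompt (Generate button)
・Output B: heatmap of the attention matrix (attention weights) computed on the selected prompt (Plot Attention button)
-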
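What's in the code (which components it has)
(1) Installing dependencies
Installs torch / numpy / matplotlib / ipywidgets / numba (torchvision / torchaudio are pinned alongside torch but are not imported).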
(2) PARAM_INIT
Collects all settings in one dict:
・Model size (d_model, n_heads, n_layers, d_ff)
・Training settings (batch_size, lr, epochs)
・Generation settings (gen_temperature, gen_top_k)
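Because every size in PARAM_INIT feeds a specific layer, the model size can be worked out by hand. The function below is a sketch that mirrors the layer list of TinyTransformerLM as written in the script (LayerNorms with weight and bias, bias-free qkv/proj/head, FFN layers with bias). `count_params` is a hypothetical helper, and vocab_size=60 in the usage line is a placeholder, since the real value depends on the corpus.

```python
def count_params(vocab_size: int, block_size: int, d_model: int, n_layers: int,
                 d_ff: int, **_ignored) -> int:
    """Trainable parameter count of a TinyTransformerLM-style model (sketch)."""
    d = d_model
    per_block = (
        2 * d                  # ln1: weight + bias
        + d * 3 * d            # qkv: Linear(d, 3d, bias=False)
        + d * d                # proj: Linear(d, d, bias=False)
        + 2 * d                # ln2: weight + bias
        + (d * d_ff + d_ff)    # FFN Linear(d, d_ff) with bias
        + (d_ff * d + d)       # FFN Linear(d_ff, d) with bias
    )
    return (
        vocab_size * d         # tok_emb
        + block_size * d       # pos_emb
        + n_layers * per_block
        + 2 * d                # ln_f: weight + bias
        + d * vocab_size       # head: Linear(d, vocab_size, bias=False)
    )


cfg = {"block_size": 64, "d_model": 128, "n_heads": 4, "n_layers": 2, "d_ff": 256}
print(count_params(vocab_size=60, **cfg))  # 287744 (vocab_size=60 is a placeholder)
```

n_heads does not enter the count: splitting d_model into heads changes how the qkv output is reshaped, not the weight shapes. In the notebook, the result can be compared with `sum(p.numel() for p in model.parameters())` (the causal_mask buffer is not a parameter).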
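(3) Dataset (corpus)
The strings in pokemon_lines are concatenated, and the result is repeated min_text_repeat times, to form the corpus.
The vocabulary is the set of characters that appear in it (Latin letters, punctuation, newline, katakana, hiragana, kanji, etc.).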
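The character vocabulary and the next-character targets can be seen on a tiny sample. This standalone sketch re-creates the script's encode/decode logic on a made-up two-line corpus; block_size=8 is chosen only to keep the printout short.

```python
import numpy as np

# Hypothetical two-line sample in the style of pokemon_lines.
corpus = "pikachu pika pika!\nポケモン いけ!\n"

chars = sorted(set(corpus))                   # vocabulary = distinct characters
stoi = {ch: i for i, ch in enumerate(chars)}  # char -> id
itos = {i: ch for ch, i in stoi.items()}      # id -> char


def encode(s: str) -> np.ndarray:
    return np.array([stoi[c] for c in s], dtype=np.int64)


def decode(ix) -> str:
    return "".join(itos[int(i)] for i in ix)


data = encode(corpus)
block_size = 8
idx = 0
x = data[idx: idx + block_size]            # input window
y = data[idx + 1: idx + 1 + block_size]    # target = the same window shifted by one
print(repr(decode(x)), "->", repr(decode(y)))  # 'pikachu ' -> 'ikachu p'
print(len(chars), "distinct characters")
```

Each position in x is trained to predict the character at the same position in y, which is exactly what CharBlockDataset returns as (x, y).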
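(4) Model: TinyTransformerLM
・Embedding (character ID → vector)
・Positional Embedding (learned position information)
・A stack of TransformerBlocks
・A final linear layer that outputs next-character scores (logits), which softmax turns into probabilities
MultiHeadSelfAttention is implemented from scratch so that it can store its attention weights, shape (B,H,T,T), in last_attn.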
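The computation inside MultiHeadSelfAttention can be restated in NumPy to check two properties of the stored weights: every Query row sums to 1, and nothing above the diagonal receives weight. This sketch assumes random inputs and a random projection matrix in place of the trained nn.Linear layers, and leaves out dropout and the output projection.

```python
import numpy as np


def causal_self_attention(x: np.ndarray, w_qkv: np.ndarray, n_heads: int):
    """Return (weights (B, H, T, T), output (B, T, C)) for causal multi-head attention."""
    B, T, C = x.shape
    d_head = C // n_heads
    q, k, v = np.split(x @ w_qkv, 3, axis=-1)          # each (B, T, C)

    def to_heads(t):
        # (B, T, C) -> (B, H, T, Dh), like view(...).transpose(1, 2) in the script
        return t.reshape(B, T, n_heads, d_head).transpose(0, 2, 1, 3)

    q, k, v = to_heads(q), to_heads(k), to_heads(v)
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(d_head)   # (B, H, T, T)
    allowed = np.tril(np.ones((T, T), dtype=bool))           # key position <= query position
    scores = np.where(allowed, scores, -np.inf)              # causal mask
    scores = scores - scores.max(axis=-1, keepdims=True)     # numerically stable softmax
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    out = (weights @ v).transpose(0, 2, 1, 3).reshape(B, T, C)  # merge heads back
    return weights, out


rng = np.random.default_rng(0)
B, T, C, H = 2, 5, 8, 4
x = rng.normal(size=(B, T, C))
w_qkv = rng.normal(size=(C, 3 * C)) / np.sqrt(C)
weights, out = causal_self_attention(x, w_qkv, H)
print(weights.shape)                             # (2, 4, 5, 5)
print(np.allclose(weights.sum(axis=-1), 1.0))    # True
print(np.allclose(np.triu(weights, k=1), 0.0))   # True
```

The first Query row can only see key 0, so its weight there is always 1; later rows spread their weight over all earlier positions.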
(5) Training
AdamW updates the parameters to minimize cross_entropy on the next character.
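The loss is the average negative log-probability that the model assigns to the true next character. The NumPy sketch below restates what `F.cross_entropy(logits.view(-1, vocab_size), yb.view(-1))` computes; the logits are made-up numbers, not model output. It also gives a reference point: a model that spreads probability evenly over V characters scores log(V), so a Train loss near log(vocab_size) suggests the model has not learned much yet.

```python
import numpy as np


def cross_entropy(logits: np.ndarray, targets: np.ndarray) -> float:
    """Mean cross-entropy over N positions; logits (N, V), targets (N,) class ids."""
    z = logits - logits.max(axis=1, keepdims=True)                  # stabilize exp()
    log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))    # log-softmax
    return float(-log_probs[np.arange(len(targets)), targets].mean())


V = 10                                   # placeholder vocabulary size
targets = np.array([3, 1, 4, 1, 5])      # "correct next character" ids
uniform = np.zeros((len(targets), V))    # no preference: every character equally likely
print(cross_entropy(uniform, targets), np.log(V))  # both about 2.3026

confident = np.full((len(targets), V), -5.0)
confident[np.arange(len(targets)), targets] = 5.0  # high score on the right character
print(cross_entropy(confident, targets))           # close to 0
```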
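(6) Generation
Starting from a prompt, the text is extended one character at a time by sampling (with temperature / top_k); only the last block_size characters are used as context.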
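The sampling rule of sample_next in the script (divide by temperature, keep only the top_k scores, softmax, draw one id) can be restated in NumPy to see how the two knobs reshape the distribution. This is an illustrative sketch with made-up logits; it uses np.argsort in place of torch.topk.

```python
import numpy as np


def sample_next(logits, temperature=1.0, top_k=None, rng=None):
    """Sample one id from a (V,) logits vector; return (id, probabilities)."""
    rng = rng if rng is not None else np.random.default_rng()
    z = np.asarray(logits, dtype=np.float64) / max(1e-8, temperature)
    if top_k is not None and top_k > 0:
        keep = np.argsort(z)[-min(top_k, z.size):]   # ids of the k largest scores
        masked = np.full_like(z, -np.inf)
        masked[keep] = z[keep]                         # all other ids get probability 0
        z = masked
    z = z - z.max()
    probs = np.exp(z) / np.exp(z).sum()
    return int(rng.choice(z.size, p=probs)), probs


rng = np.random.default_rng(0)
logits = np.array([2.0, 1.0, 0.5, -1.0])   # made-up scores for 4 characters
_, p_cold = sample_next(logits, temperature=0.5, rng=rng)
_, p_hot = sample_next(logits, temperature=2.0, rng=rng)
i_top2, p_top2 = sample_next(logits, top_k=2, rng=rng)
print(p_cold.round(3))   # sharper: most mass on id 0
print(p_hot.round(3))    # flatter
print(p_top2.round(3))   # only ids 0 and 1 keep probability
```

Lower temperature makes output more repetitive and predictable; a small top_k prevents rare characters from ever being drawn.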
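(7) Attention visualization
Displays the attention matrix of the last block for one selected head with imshow.
Rows are Query positions and columns are Key positions, so you can see which earlier characters each position attends to. Because of the causal mask, the weights above the diagonal are zero.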
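To read the heatmap numerically, one can take each Query row and find the Key position with the largest weight, searching only keys 0..t because of the causal mask. strongest_links is a hypothetical helper, and the 4×4 matrix below is hand-written for illustration, not output of the trained model. In the notebook, a real matrix for one head could be taken from `model.get_last_attention()[0, head].cpu().numpy()` right after a forward pass such as plot_attention.

```python
import numpy as np


def strongest_links(A, tokens):
    """For each query position t, return (query char, most-attended key char, weight)."""
    A = np.asarray(A)
    links = []
    for t in range(A.shape[0]):
        j = int(np.argmax(A[t, : t + 1]))   # causal: only keys 0..t are valid
        links.append((tokens[t], tokens[j], float(A[t, j])))
    return links


tokens = list("pika")
A = np.array([                    # rows = Query, columns = Key; each row sums to 1
    [1.0, 0.0, 0.0, 0.0],
    [0.7, 0.3, 0.0, 0.0],
    [0.2, 0.5, 0.3, 0.0],
    [0.1, 0.1, 0.2, 0.6],
])
for q, k, w in strongest_links(A, tokens):
    print(f"{q!r} attends most to {k!r} ({w:.2f})")
```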
(8) UI
ipywidgets buttons run each step:
・Train
・Generate
・Plot Attention
・Corpus Stats
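The Corpus Stats button counts characters with a numba-compiled loop (token_histogram). For valid ids in [0, vocab_size), np.bincount with minlength produces the same counts, which makes a convenient cross-check. A sketch on a stand-in corpus (not the script's real corpus):

```python
from collections import Counter

import numpy as np

corpus = "pikachu pika pika!\n" * 3                # stand-in for the real corpus
chars = sorted(set(corpus))
stoi = {ch: i for i, ch in enumerate(chars)}
data = np.array([stoi[c] for c in corpus], dtype=np.int64)

counts = np.bincount(data, minlength=len(chars))   # same role as token_histogram(data, vocab_size)
for i in np.argsort(-counts)[:5]:
    print(f"{chars[i]!r}: {counts[i]}")           # repr() keeps "\n" and " " visible

# Cross-check against a plain Python character count.
assert all(counts[stoi[c]] == n for c, n in Counter(corpus).items())
```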
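-
What makes this a "Transformer" / "attention" model
・The core of a Transformer is Self-Attention: at each position, learned query/key projections determine how much weight the model puts on each earlier character.
・This code extracts those weights (the attention) and displays them.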
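In formula form, the per-head weights that the script stores in last_attn (in eval mode, where dropout is inactive) are standard causal scaled dot-product attention. Here d_h is the per-head size (d_model / n_heads, 32 with the default settings) and M is the causal mask that the code applies with masked_fill(attn_mask == 0, -inf):

```latex
A = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_h}} + M\right),
\qquad
M_{ij} =
\begin{cases}
0, & j \le i \\
-\infty, & j > i
\end{cases}
\qquad
\text{output} = A V
```

Row i of A is one row of the heatmap: a probability distribution over key positions 0..i.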
# Program Name: pokemon_transformer_attention_colab.py
# Creation Date: 20260122
# Purpose: Train a tiny Transformer (self-attention) on Pokémon-themed text and visualize attention.
# ============================================================
# 0) Install / Imports
# ============================================================
# NOTE: Colab usually has torch preinstalled, but we pin versions for reproducibility.
# "!pip" is IPython/Colab syntax and is not valid in a plain .py file.
# A failing "!" command does not raise a Python exception, so check pip's output;
# if any pinned package was changed, restart the runtime before running the rest.
!pip -q install "numpy==1.26.4" "matplotlib==3.8.4" "ipywidgets==8.1.2" "numba==0.59.1" "torch==2.2.2" "torchvision==0.17.2" "torchaudio==2.2.2"
import os
import math
import time
import random
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from numba import njit
import ipywidgets as widgets
from IPython.display import display, clear_output
# ============================================================
# 1) PARAM_INIT (centralized settings)
# ============================================================
PARAM_INIT = {
    # Seed
    "seed": 123,
    # Dataset
    "block_size": 64,            # context length (characters)
    "min_text_repeat": 3,        # repeat the corpus to increase the number of samples
    # Model
    "d_model": 128,
    "n_heads": 4,
    "n_layers": 2,
    "d_ff": 256,
    "dropout": 0.10,
    # Training
    "batch_size": 64,
    "lr": 2e-3,
    "weight_decay": 1e-4,
    "epochs": 3,
    "max_steps_per_epoch": 200,  # upper bound; an epoch also ends when the DataLoader is exhausted
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    # Generation
    "gen_max_new_chars": 200,
    "gen_temperature": 1.0,
    "gen_top_k": 30,
    # Output saving
    "out_dir": "/mnt/data",
}
os.makedirs(PARAM_INIT["out_dir"], exist_ok=True)
# ============================================================
# 2) Reproducibility
# ============================================================
def set_seed(seed: int):
    """Inputs/Outputs/Process:
    Inputs: seed (int)
    Outputs: None
    Process: Fix random seeds for python/numpy/torch.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


set_seed(PARAM_INIT["seed"])
# ============================================================
# 3) Pokémon-themed corpus (toy)
# ============================================================
# English + Japanese (katakana, hiragana, kanji) phrases; character-level modeling.
pokemon_lines = [
    "pikachu pika pika!\n",
    "bulbasaur bulba saur!\n",
    "charmander char!\n",
    "squirtle squirt squirt!\n",
    "eevee v~ee!\n",
    "mew mew...\n",
    "mewtwo ...\n",
    "gengar genga!\n",
    "snorlax zzz...\n",
    "lucario aura!\n",
    "ポケモン いけ!\n",
    "ピカチュウ でんきショック!\n",
    "フシギダネ はっぱカッター!\n",
    "ヒトカゲ ひのこ!\n",
    "ゼニガメ みずでっぽう!\n",
    "進化 する!\n",
    "きみに きめた!\n",
    "trainer battle start!\n",
    "critical hit!\n",
    "use thunderbolt!\n",
]
corpus = "".join(pokemon_lines) * PARAM_INIT["min_text_repeat"]

# Build vocab (character set)
chars = sorted(list(set(corpus)))
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for ch, i in stoi.items()}
vocab_size = len(chars)


def encode(s: str) -> np.ndarray:
    return np.array([stoi[c] for c in s], dtype=np.int64)


def decode(ix: np.ndarray) -> str:
    return "".join([itos[int(i)] for i in ix])


data = encode(corpus)
# ============================================================
# 4) Dataset / DataLoader
# ============================================================
class CharBlockDataset(Dataset):
    """Inputs/Outputs/Process:
    Inputs: token array, block_size
    Outputs: (x, y) where y is x shifted by one position (next-token targets)
    Process: Return the contiguous block starting at idx; the DataLoader's
             shuffle=True makes the blocks random.
    """

    def __init__(self, tokens: np.ndarray, block_size: int):
        self.tokens = tokens
        self.block_size = block_size

    def __len__(self):
        return max(1, len(self.tokens) - self.block_size - 1)

    def __getitem__(self, idx):
        x = self.tokens[idx: idx + self.block_size]
        y = self.tokens[idx + 1: idx + 1 + self.block_size]
        return torch.tensor(x, dtype=torch.long), torch.tensor(y, dtype=torch.long)


dataset = CharBlockDataset(data, PARAM_INIT["block_size"])
loader = DataLoader(dataset, batch_size=PARAM_INIT["batch_size"], shuffle=True, drop_last=True)
# ============================================================
# 5) Transformer with exportable attention weights (custom)
# ============================================================
class MultiHeadSelfAttention(nn.Module):
    """Self-attention with weight export.
    Implemented by hand so that the attention weights can be read out.
    """

    def __init__(self, d_model: int, n_heads: int, dropout: float):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.proj = nn.Linear(d_model, d_model, bias=False)
        self.drop = nn.Dropout(dropout)
        self.last_attn = None  # (B, H, T, T)

    def forward(self, x, attn_mask=None):
        B, T, C = x.shape
        qkv = self.qkv(x)  # (B, T, 3C)
        q, k, v = qkv.chunk(3, dim=-1)
        # reshape to heads
        q = q.view(B, T, self.n_heads, self.d_head).transpose(1, 2)  # (B, H, T, Dh)
        k = k.view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        # scaled dot-product attention
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_head)  # (B, H, T, T)
        if attn_mask is not None:
            scores = scores.masked_fill(attn_mask == 0, float("-inf"))
        attn = F.softmax(scores, dim=-1)
        self.last_attn = attn.detach()  # store for visualization (before dropout, so rows sum to 1)
        attn = self.drop(attn)
        out = attn @ v  # (B, H, T, Dh)
        out = out.transpose(1, 2).contiguous().view(B, T, C)  # (B, T, C)
        out = self.proj(out)
        out = self.drop(out)
        return out


class TransformerBlock(nn.Module):
    """Transformer block: LN -> MHA -> residual -> LN -> FFN -> residual."""

    def __init__(self, d_model, n_heads, d_ff, dropout):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = MultiHeadSelfAttention(d_model, n_heads, dropout)
        self.ln2 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x, attn_mask=None):
        x = x + self.attn(self.ln1(x), attn_mask=attn_mask)
        x = x + self.ff(self.ln2(x))
        return x


class TinyTransformerLM(nn.Module):
    """Character-level Transformer language model."""

    def __init__(self, vocab_size, block_size, d_model, n_heads, n_layers, d_ff, dropout):
        super().__init__()
        self.vocab_size = vocab_size
        self.block_size = block_size
        self.d_model = d_model
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(block_size, d_model)
        self.drop = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)
        ])
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)
        # causal mask (no position can see the future): 1 = allowed, 0 = masked
        mask = torch.tril(torch.ones(block_size, block_size)).unsqueeze(0).unsqueeze(0)  # (1,1,T,T)
        self.register_buffer("causal_mask", mask)

    def forward(self, idx):
        B, T = idx.shape
        assert T <= self.block_size
        pos = torch.arange(0, T, device=idx.device)
        x = self.tok_emb(idx) + self.pos_emb(pos)[None, :, :]
        x = self.drop(x)
        attn_mask = self.causal_mask[:, :, :T, :T]
        for blk in self.blocks:
            x = blk(x, attn_mask=attn_mask)
        x = self.ln_f(x)
        logits = self.head(x)  # (B,T,V)
        return logits

    def get_last_attention(self):
        # Return last block's attention weights if available
        if len(self.blocks) == 0:
            return None
        return self.blocks[-1].attn.last_attn  # (B,H,T,T)


model = TinyTransformerLM(
    vocab_size=vocab_size,
    block_size=PARAM_INIT["block_size"],
    d_model=PARAM_INIT["d_model"],
    n_heads=PARAM_INIT["n_heads"],
    n_layers=PARAM_INIT["n_layers"],
    d_ff=PARAM_INIT["d_ff"],
    dropout=PARAM_INIT["dropout"],
).to(PARAM_INIT["device"])
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=PARAM_INIT["lr"],
    weight_decay=PARAM_INIT["weight_decay"],
)
# ============================================================
# 6) Training
# ============================================================
def train_one_epoch(epoch_i: int):
    """Inputs/Outputs/Process:
    Inputs: epoch index
    Outputs: average loss
    Process: Train for at most max_steps_per_epoch steps (fewer if the loader runs out).
    """
    model.train()
    losses = []
    step = 0
    for xb, yb in loader:
        xb = xb.to(PARAM_INIT["device"])
        yb = yb.to(PARAM_INIT["device"])
        logits = model(xb)
        loss = F.cross_entropy(logits.view(-1, vocab_size), yb.view(-1))
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        losses.append(float(loss.detach().cpu().item()))
        step += 1
        if step >= PARAM_INIT["max_steps_per_epoch"]:
            break
    return float(np.mean(losses)) if losses else float("nan")
# ============================================================
# 7) Generation
# ============================================================
@torch.no_grad()
def sample_next(logits, temperature=1.0, top_k=None):
    """Sample from logits with temperature and optional top-k."""
    logits = logits / max(1e-8, temperature)
    if top_k is not None and top_k > 0:
        v, ix = torch.topk(logits, min(top_k, logits.shape[-1]))
        mask = torch.full_like(logits, float("-inf"))
        mask.scatter_(dim=-1, index=ix, src=v)
        logits = mask
    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)


@torch.no_grad()
def generate(prompt: str, max_new_chars: int, temperature: float, top_k: int):
    """The prompt may only use characters in the corpus vocabulary (encode raises KeyError otherwise)."""
    model.eval()
    idx = torch.tensor(encode(prompt), dtype=torch.long, device=PARAM_INIT["device"])[None, :]
    for _ in range(max_new_chars):
        idx_cond = idx[:, -PARAM_INIT["block_size"]:]  # crop to the context window
        logits = model(idx_cond)
        last = logits[:, -1, :]  # scores for the next character only
        nxt = sample_next(last, temperature=temperature, top_k=top_k)
        idx = torch.cat([idx, nxt], dim=1)
    out = idx[0].detach().cpu().numpy()
    return decode(out)
# ============================================================
# 8) Attention visualization
# ============================================================
def plot_attention(prompt: str, max_len: int = 64, head: int = 0, save: bool = False):
    """Plot the attention matrix for a prompt (last block, one head)."""
    model.eval()
    x = torch.tensor(encode(prompt), dtype=torch.long, device=PARAM_INIT["device"])[None, :]
    x = x[:, :min(x.shape[1], max_len)]
    T = x.shape[1]
    with torch.no_grad():
        _ = model(x)  # the forward pass fills last_attn in every block
    attn = model.get_last_attention()
    if attn is None:
        print("No attention weights available.")
        return
    # attn: (B,H,T,T)
    A = attn[0, head, :T, :T].detach().cpu().numpy()
    plt.figure(figsize=(7, 6))
    plt.imshow(A, aspect="auto")
    plt.colorbar()
    plt.title(f"Attention Weights (Last Block, head {head})")
    plt.xlabel("Key position")
    plt.ylabel("Query position")
    # Token labels (chars); newlines are shown as a literal backslash-n.
    # Katakana/kanji labels may render as empty boxes unless a font with
    # Japanese glyphs is configured for matplotlib.
    toks = [itos[int(i)] for i in x[0].detach().cpu().numpy()]
    toks = ["\\n" if t == "\n" else t for t in toks]
    step = max(1, T // 16)  # avoid clutter
    ticks = list(range(0, T, step))
    plt.xticks(ticks, [toks[i] for i in ticks], rotation=90)
    plt.yticks(ticks, [toks[i] for i in ticks])
    plt.tight_layout()
    if save:
        ts = time.strftime("%Y%m%d_%H%M%S")
        fname_png = os.path.join(PARAM_INIT["out_dir"], f"attn_{ts}_head{head}.png")
        fname_pdf = os.path.join(PARAM_INIT["out_dir"], f"attn_{ts}_head{head}.pdf")
        plt.savefig(fname_png, dpi=200)
        plt.savefig(fname_pdf)
        print("Saved:", fname_png)
        print("Saved:", fname_pdf)
    plt.show()
# ============================================================
# 9) (Optional) Numba helper for a fast character histogram
# ============================================================
@njit
def token_histogram(tokens, vocab_size):
    h = np.zeros(vocab_size, dtype=np.int64)
    for i in range(tokens.size):
        t = tokens[i]
        if 0 <= t < vocab_size:
            h[t] += 1
    return h


def show_corpus_stats():
    h = token_histogram(data, vocab_size)
    top = np.argsort(-h)[:20]
    print("Top chars by frequency:")
    for i in top:
        ch = itos[int(i)]
        print(f"  {ch!r} : {int(h[i])}")  # repr() keeps "\n" and " " visible
# ============================================================
# 10) UI (ipywidgets)
# ============================================================
prompt_dropdown = widgets.Dropdown(
    options=[
        "pikachu pika pika!\n",
        "ポケモン いけ!\n",
        "ピカチュウ でんきショック!\n",
        "trainer battle start!\n",
        "use thunderbolt!\n",
        "進化 する!\n",
    ],
    value="pikachu pika pika!\n",
    description="Prompt:",
    layout=widgets.Layout(width="70%")
)
epochs_slider = widgets.IntSlider(
    value=PARAM_INIT["epochs"], min=1, max=20, step=1,
    description="Epochs:", continuous_update=False
)
steps_slider = widgets.IntSlider(
    value=PARAM_INIT["max_steps_per_epoch"], min=50, max=2000, step=50,
    description="Steps/ep:", continuous_update=False
)
temp_slider = widgets.FloatSlider(
    value=PARAM_INIT["gen_temperature"], min=0.3, max=2.0, step=0.1,
    description="Temp:", continuous_update=False
)
topk_slider = widgets.IntSlider(
    value=PARAM_INIT["gen_top_k"], min=0, max=100, step=5,
    description="Top-k:", continuous_update=False
)
genlen_slider = widgets.IntSlider(
    value=120, min=20, max=500, step=10,
    description="New chars:", continuous_update=False
)
head_slider = widgets.IntSlider(
    value=0, min=0, max=PARAM_INIT["n_heads"] - 1, step=1,
    description="Head:", continuous_update=False
)
save_checkbox = widgets.Checkbox(value=False, description="Save PNG/PDF")
btn_train = widgets.Button(description="Train", button_style="primary")
btn_generate = widgets.Button(description="Generate", button_style="success")
btn_attn = widgets.Button(description="Plot Attention", button_style="info")
btn_stats = widgets.Button(description="Corpus Stats", button_style="")
out = widgets.Output()


def on_train_clicked(_):
    with out:
        clear_output()
        try:
            PARAM_INIT["epochs"] = int(epochs_slider.value)
            PARAM_INIT["max_steps_per_epoch"] = int(steps_slider.value)
            print(f"Device: {PARAM_INIT['device']}")
            print(f"Vocab size: {vocab_size}, Block size: {PARAM_INIT['block_size']}")
            for ep in range(PARAM_INIT["epochs"]):
                loss = train_one_epoch(ep)
                print(f"Epoch {ep+1}/{PARAM_INIT['epochs']} Loss: {loss:.4f}")
        except Exception as e:
            print("[ERROR] Training failed:", e)


def on_generate_clicked(_):
    with out:
        clear_output()
        try:
            prompt = prompt_dropdown.value
            text = generate(
                prompt=prompt,
                max_new_chars=int(genlen_slider.value),
                temperature=float(temp_slider.value),
                top_k=int(topk_slider.value) if int(topk_slider.value) > 0 else None
            )
            print("----- Generated Text -----")
            print(text)
        except Exception as e:
            print("[ERROR] Generation failed:", e)


def on_attn_clicked(_):
    with out:
        clear_output()
        try:
            prompt = prompt_dropdown.value
            plot_attention(
                prompt=prompt,
                max_len=PARAM_INIT["block_size"],
                head=int(head_slider.value),
                save=bool(save_checkbox.value)
            )
        except Exception as e:
            print("[ERROR] Attention plot failed:", e)


def on_stats_clicked(_):
    with out:
        clear_output()
        try:
            show_corpus_stats()
        except Exception as e:
            print("[ERROR] Stats failed:", e)


btn_train.on_click(on_train_clicked)
btn_generate.on_click(on_generate_clicked)
btn_attn.on_click(on_attn_clicked)
btn_stats.on_click(on_stats_clicked)
ui = widgets.VBox([
    widgets.HBox([prompt_dropdown]),
    widgets.HBox([epochs_slider, steps_slider]),
    widgets.HBox([temp_slider, topk_slider, genlen_slider]),
    widgets.HBox([head_slider, save_checkbox]),
    widgets.HBox([btn_train, btn_generate, btn_attn, btn_stats]),
    out
])
display(ui)
# ============================================================
# 11) Quick sanity run (optional) / 動作確認(任意)
# ============================================================
# Click "Train" first, then check "Generate" and "Plot Attention".