This page describes the component technologies behind AI.
Reference
Multi-head Attention [a building block of the Transformer] (CVML Expert Guide)
Understanding
The model introduced in 2017 (in the paper "Attention is All You Need")
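The core operation behind multi-head attention is scaled dot-product attention. A minimal standalone sketch (independent of the sample program further down):

# Scaled dot-product attention in isolation; a minimal standalone sketch.
import math
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V, mask=None):
    # Q: (..., Tq, d_k), K: (..., Tk, d_k), V: (..., Tk, d_v)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(Q.size(-1))  # similarity of queries and keys
    if mask is not None:
        scores = scores + mask                                 # -inf on masked positions
    weights = F.softmax(scores, dim=-1)                        # attention weights
    return weights @ V                                         # weighted sum of values

# Multi-head attention applies this in parallel on n_heads smaller slices of d_model.
Q = K = V = torch.randn(2, 8, 16)      # (batch, sequence length, head dim)
out = scaled_dot_product_attention(Q, K, V)
print(out.shape)                       # torch.Size([2, 8, 16])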
Sample program
transformer_seq2seq.py
Q/K/V (query, key, value) are produced by linear projections → split into heads → scaled dot-product attention
EncoderLayer: Self-Attention + FFN
DecoderLayer: causal Self-Attention + Cross-Attention (source-target attention) + FFN
Positional encoding (sinusoidal)
Padding masks and the decoder's causal mask are applied in combination
Because training is at the character level (char), the output ends up as character fragments strung together in a vaguely plausible way
Switching from characters to subword tokens such as BPE/SentencePiece makes it easier to learn word- and phrase-level units and greatly improves how "English-like" the output reads. However, the vocabulary then grows to tens of thousands of entries, the implementation becomes more complex, and the computational cost increases (see the tokenization sketch below)
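For reference, a minimal sketch of subword tokenization with the sentencepiece package; this is not part of transformer_seq2seq.py, and the model prefix "tiny_bpe" and the vocabulary size of 4000 are illustrative assumptions:

# pip install sentencepiece
# Minimal subword (BPE) tokenization sketch; illustrative only, not part of transformer_seq2seq.py.
import sentencepiece as spm

# Train a small BPE model on the same tiny-shakespeare text file used by the sample program.
spm.SentencePieceTrainer.train(
    input="./data/tinyshakespeare.txt",  # same file the sample program downloads
    model_prefix="tiny_bpe",             # hypothetical output name (tiny_bpe.model / tiny_bpe.vocab)
    vocab_size=4000,                     # assumed size; the char-level vocabulary is only 66
    model_type="bpe",
)

sp = spm.SentencePieceProcessor(model_file="tiny_bpe.model")
ids = sp.encode("But soft! what light through yonder window breaks?", out_type=int)
print(len(ids), ids[:10])  # far fewer tokens than characters
print(sp.decode(ids))      # should reproduce the input text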
Parameter tuning
# Configuration 1 (small)
d_model = 128
n_heads = 4
d_ff = 256
num_encoder_layers = 3
num_decoder_layers = 3
dropout_rate = 0.1
batch_size = 32
epochs = 8

# Configuration 2 (medium)
d_model = 384
n_heads = 8
d_ff = 4 * d_model  # 1536
num_encoder_layers = 4
num_decoder_layers = 4
dropout_rate = 0.2
batch_size = 16  # 16 on a PC with 16 GB of RAM; 32 is fine with d_model=256 and enc/dec=3/3
epochs = 100

# Configuration 3 (large)
d_model = 512
n_heads = 8
d_ff = 4 * d_model  # 2048
num_encoder_layers = 6
num_decoder_layers = 6
dropout_rate = 0.3
batch_size = 32
epochs = 120
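To compare the size of these three settings, the total parameter count can be read off the model itself. A minimal sketch, assuming the StandardEncoderDecoder class defined in the sample program below and the tiny-shakespeare character vocabulary (66 symbols); count_params is an illustrative helper, not part of the program:

# Rough parameter-count check for the configurations above (a sketch; assumes
# StandardEncoderDecoder from transformer_seq2seq.py below and vocab_size=66).
def count_params(d_model, n_heads, d_ff, n_enc, n_dec, vocab_size=66):
    model = StandardEncoderDecoder(
        vocab_size,
        d_model=d_model,
        n_heads=n_heads,
        d_ff=d_ff,
        num_encoder_layers=n_enc,
        num_decoder_layers=n_dec,
    )
    return sum(p.numel() for p in model.parameters())

# count_params(128, 4, 256, 3, 3)      -> roughly  1M parameters
# count_params(384, 8, 4 * 384, 4, 4)  -> roughly 17M parameters
# count_params(512, 8, 4 * 512, 6, 6)  -> roughly 44M parameters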
Sample program
pip install torch
# -*- coding: utf-8 -*-
"""
A standard Transformer (encoder-decoder) trained on tiny-shakespeare
- Character level (char-level)
- Each decoder layer: Self-Attn → Cross-Attn → FFN (standard layout)
- Sampling-based generation (temperature / top-k / repetition_penalty)
"""
import os, math, urllib.request, pathlib, random
import torch
import torch.nn as nn
import torch.nn.functional as F
# =========================
# Hyperparameters
# =========================
device = "cuda" if torch.cuda.is_available() else "cpu"
seed = 42
random.seed(seed); torch.manual_seed(seed)
data_dir = "./data"
tiny_url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
tiny_path = os.path.join(data_dir, "tinyshakespeare.txt")
# Model & training (a setting that still runs on CPU; epochs extended; one step up for CPU / light GPU; roughly ~18M parameters)
d_model = 384
n_heads = 6
d_ff = 4 * d_model # 1536
max_len = 512
src_len = 128
tgt_len = 128
batch_size = 16
epochs = 100  # train for longer
lr = 3e-4
bos_token_str = "<BOS>"
# Number of layers (configurable)
num_encoder_layers = 4
num_decoder_layers = 4
dropout_rate = 0.2
# Default generation (sampling) settings
gen_steps = 300
gen_temperature = 0.7
gen_top_k = 20
gen_repetition_penalty = 1.3
# =========================
# Data download
# =========================
def ensure_tiny_shakespeare():
    pathlib.Path(data_dir).mkdir(parents=True, exist_ok=True)
    if not os.path.exists(tiny_path):
        print("Downloading tiny-shakespeare...")
        urllib.request.urlretrieve(tiny_url, tiny_path)
        print("Saved:", tiny_path)

with torch.no_grad():
    ensure_tiny_shakespeare()

with open(tiny_path, "r", encoding="utf-8") as f:
    text = f.read()
# Character vocabulary (+ BOS)
chars = sorted(list(set(text)))
if bos_token_str not in chars:
    chars.append(bos_token_str)
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for ch, i in stoi.items()}
vocab_size = len(chars)
bos_id = stoi[bos_token_str]

def encode(s: str):
    return torch.tensor([stoi[c] for c in s], dtype=torch.long)

def decode(ids):
    return "".join(itos[int(i)] for i in ids)
# =========================
# Mini-batch loader
# =========================
def get_batch(batch_size, src_len, tgt_len):
    B = batch_size
    total_len = src_len + tgt_len
    xs = torch.zeros((B, src_len), dtype=torch.long)
    ys = torch.zeros((B, tgt_len), dtype=torch.long)
    for b in range(B):
        # Pick a random slice: the first src_len chars go to the encoder,
        # the following tgt_len chars are the continuation to predict.
        i = random.randint(0, len(text) - total_len - 1)
        chunk = text[i : i + total_len]
        xs[b] = encode(chunk[:src_len])
        ys[b] = encode(chunk[src_len:])
    return xs.to(device), ys.to(device)
# =========================
# Positional encoding (sinusoidal)
# =========================
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 4096):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, T, D); add the first T positional vectors (broadcast over the batch)
        T = x.size(1)
        return x + self.pe[:T, :]
# =========================
# Multi-Head Attention
# =========================
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int, bias: bool = True):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.W_q = nn.Linear(d_model, d_model, bias=bias)
        self.W_k = nn.Linear(d_model, d_model, bias=bias)
        self.W_v = nn.Linear(d_model, d_model, bias=bias)
        self.W_o = nn.Linear(d_model, d_model, bias=bias)

    def _split(self, x):
        B, T, D = x.shape
        return x.view(B, T, self.n_heads, self.d_head).transpose(1, 2)  # (B,H,T,Dh)

    def _merge(self, x):
        B, H, T, Dh = x.shape
        return x.transpose(1, 2).contiguous().view(B, T, H * Dh)

    def forward(self, q_inp, k_inp, v_inp, attn_mask=None, return_qkv=False):
        Q = self._split(self.W_q(q_inp))
        K = self._split(self.W_k(k_inp))
        V = self._split(self.W_v(v_inp))
        scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.d_head)  # (B,H,Tq,Tk)
        if attn_mask is not None:
            scores = scores + attn_mask
        attn = F.softmax(scores, dim=-1)
        ctx = attn @ V
        out = self.W_o(self._merge(ctx))
        if return_qkv:
            return out, Q, K, V
        return out
# =========================
# Encoder layer (Self-Attn → FFN)
# =========================
class EncoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff=256, dropout=0.1):
        super().__init__()
        self.self_mha = MultiHeadAttention(d_model, n_heads)
        self.ln1 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
        self.ln2 = nn.LayerNorm(d_model)
        self.do = nn.Dropout(dropout)

    def forward(self, x, attn_mask=None, return_qkvs=False):
        attn_out, Q, K, V = self.self_mha(x, x, x, attn_mask, return_qkv=True)
        x = self.ln1(x + self.do(attn_out))  # residual + post-LayerNorm
        ff_out = self.ff(x)
        x = self.ln2(x + self.do(ff_out))
        if return_qkvs:
            return x, (Q, K, V)
        return x
# =========================
# Decoder layer (Self-Attn → Cross-Attn → FFN)
# =========================
class DecoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff=256, dropout=0.1):
        super().__init__()
        self.self_mha = MultiHeadAttention(d_model, n_heads)
        self.ln1 = nn.LayerNorm(d_model)
        self.cross_mha = MultiHeadAttention(d_model, n_heads)
        self.ln2 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
        self.ln3 = nn.LayerNorm(d_model)
        self.do = nn.Dropout(dropout)

    @staticmethod
    def causal_mask(T_q, T_k, device):
        mask = torch.full((T_q, T_k), float("-inf"), device=device)
        mask = torch.triu(mask, diagonal=1)    # upper triangle set to -inf (hides future positions)
        return mask.unsqueeze(0).unsqueeze(0)  # (1,1,Tq,Tk)

    def forward(self, x, memory, device, memory_pad_mask=None, return_self_q=False):
        # 1) Masked Self-Attn
        B, T_tgt, _ = x.shape
        self_mask = self.causal_mask(T_tgt, T_tgt, device)
        self_out, Q_self, K_self, V_self = self.self_mha(x, x, x, self_mask, return_qkv=True)
        x = self.ln1(x + self.do(self_out))
        # 2) Cross-Attn (queries from the decoder, keys/values from the encoder memory)
        cross_mask = None
        if memory_pad_mask is not None:
            m = memory_pad_mask.unsqueeze(1).unsqueeze(2)  # (B,1,1,T_src)
            cross_mask = m * torch.finfo(x.dtype).min
        cross_out, Q_cross, K_cross, V_cross = self.cross_mha(x, memory, memory, cross_mask, return_qkv=True)
        x = self.ln2(x + self.do(cross_out))
        # 3) FFN
        ff_out = self.ff(x)
        x = self.ln3(x + self.do(ff_out))
        if return_self_q:
            return x, Q_self
        return x
# =========================
# Encoder / decoder stacks
# =========================
class EncoderStack(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, num_layers, max_len=4096, dropout=0.1):
        super().__init__()
        self.pe = PositionalEncoding(d_model, max_len)
        self.layers = nn.ModuleList([EncoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(num_layers)])

    def forward(self, src_embed, src_pad_mask=None, return_last_qkv=True):
        x = self.pe(src_embed)
        attn_mask = None
        if src_pad_mask is not None:
            mask = src_pad_mask.unsqueeze(1).unsqueeze(2)  # (B,1,1,T_src)
            attn_mask = mask * torch.finfo(x.dtype).min
        last_Q = last_K = last_V = None
        for i, layer in enumerate(self.layers):
            if return_last_qkv and (i == len(self.layers) - 1):
                x, (last_Q, last_K, last_V) = layer(x, attn_mask, return_qkvs=True)
            else:
                x = layer(x, attn_mask, return_qkvs=False)
        return x, last_K, last_V  # encoder key / value from the final layer

class DecoderStack(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, num_layers, max_len=4096, dropout=0.1):
        super().__init__()
        self.pe = PositionalEncoding(d_model, max_len)
        self.layers = nn.ModuleList([DecoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(num_layers)])

    def forward(self, tgt_embed, memory, device, memory_pad_mask=None, return_last_self_q=True):
        x = self.pe(tgt_embed)
        last_Q_self = None
        for i, layer in enumerate(self.layers):
            if return_last_self_q and (i == len(self.layers) - 1):
                x, last_Q_self = layer(x, memory, device, memory_pad_mask, return_self_q=True)
            else:
                x = layer(x, memory, device, memory_pad_mask, return_self_q=False)
        return x, last_Q_self  # decoder query (Q of the final layer's self-attention)
# =========================
# Standard Transformer (encoder-decoder)
# =========================
class StandardEncoderDecoder(nn.Module):
    def __init__(
        self,
        vocab_size,
        d_model=128,
        n_heads=4,
        d_ff=256,
        max_len=4096,
        num_encoder_layers=3,
        num_decoder_layers=3,
        dropout=0.1,
    ):
        super().__init__()
        self.d_model = d_model
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.encoder = EncoderStack(d_model, n_heads, d_ff, num_encoder_layers, max_len=max_len, dropout=dropout)
        self.decoder = DecoderStack(d_model, n_heads, d_ff, num_decoder_layers, max_len=max_len, dropout=dropout)
        self.out_proj = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, src_ids, tgt_ids, src_pad_mask=None):
        src_embed = self.tok_emb(src_ids) * math.sqrt(self.d_model)
        tgt_embed = self.tok_emb(tgt_ids) * math.sqrt(self.d_model)
        memory, encoder_key, encoder_value = self.encoder(src_embed, src_pad_mask, return_last_qkv=True)
        dec_out, decoder_query = self.decoder(tgt_embed, memory, src_ids.device,
                                              memory_pad_mask=src_pad_mask,
                                              return_last_self_q=True)
        logits = self.out_proj(dec_out)  # (B,T_tgt,V)
        aux = {
            "encoder_key": encoder_key,
            "encoder_value": encoder_value,
            "decoder_query": decoder_query,
            "memory": memory,
            "dec_out": dec_out,
        }
        return logits, aux
# =========================
# Training utilities
# =========================
def sequence_ce_loss(logits, target_ids, label_smoothing: float = 0.0):
    """
    logits: (B,T,V), target_ids: (B,T)
    Next-token prediction: logits[:, :-1] vs target[:, 1:]
    """
    B, T, V = logits.shape
    logits_shift = logits[:, :-1, :].contiguous().view(-1, V)
    target_shift = target_ids[:, 1:].contiguous().view(-1)
    return F.cross_entropy(logits_shift, target_shift, label_smoothing=label_smoothing)
# =========================
# Sampling-based generation (temperature / top-k / repetition penalty)
# =========================
@torch.no_grad()
def sample_generate(model,
                    src_prompt: str,
                    steps: int = 300,
                    temperature: float = 0.9,
                    top_k: int | None = 40,
                    repetition_penalty: float = 1.1):
    """
    - Generates by sampling rather than greedy decoding (a greedy variant is sketched after this listing)
    - repetition_penalty slightly suppresses tokens that have already been generated
    """
    model.eval()
    src_ids = encode(src_prompt).unsqueeze(0).to(device)                  # (1, Ts)
    dec_ids = torch.tensor([[bos_id]], dtype=torch.long, device=device)  # (1,1)
    src_pad_mask = None
    for _ in range(steps):
        logits, _ = model(src_ids, dec_ids, src_pad_mask=src_pad_mask)  # (1, Td, V)
        next_logits = logits[:, -1, :]
        # temperature
        next_logits = next_logits / max(temperature, 1e-6)
        # repetition penalty (simple): weaken the logits of tokens already emitted
        if repetition_penalty and repetition_penalty > 1.0:
            for tok in dec_ids[0].tolist():
                next_logits[0, tok] /= repetition_penalty
        # top-k filter
        if top_k is not None and top_k > 0:
            k = min(top_k, next_logits.size(-1))
            v, _ = torch.topk(next_logits, k=k)
            thresh = v[:, -1].unsqueeze(-1)
            next_logits = torch.where(next_logits < thresh, torch.full_like(next_logits, float("-inf")), next_logits)
        probs = torch.softmax(next_logits, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)  # sample
        dec_ids = torch.cat([dec_ids, next_id], dim=1)
    gen_ids = dec_ids[0, 1:].tolist()
    return decode(gen_ids)
# ======================================
# Save / load (weights only)
# ======================================
CHECKPOINT_DIR = "./checkpoints"
CHECKPOINT_PATH = os.path.join(CHECKPOINT_DIR, "ed_standard_weights.pth")

def save_weights(model):
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    torch.save(model.state_dict(), CHECKPOINT_PATH)
    print(f"[save] model weights saved to {CHECKPOINT_PATH}")

def load_weights(model):
    state = torch.load(CHECKPOINT_PATH, map_location=device)
    model.load_state_dict(state)
    model.eval()
    print(f"[load] model weights loaded from {CHECKPOINT_PATH}")
# =========================
# Main
# =========================
def main():
    print(f"Device: {device}")
    print(f"Vocab size: {vocab_size}, data length: {len(text)}")

    model = StandardEncoderDecoder(
        vocab_size,
        d_model=d_model,
        n_heads=n_heads,
        d_ff=d_ff,
        max_len=max_len,
        num_encoder_layers=num_encoder_layers,
        num_decoder_layers=num_decoder_layers,
        dropout=dropout_rate,
    ).to(device)

    mode = "train"  # "train" trains and saves; set to "load" to load the saved weights

    if mode == "train":
        # ============================
        # Training mode: save weights once at the end
        # ============================
        steps_per_epoch = max(1, len(text) // (batch_size * (src_len + tgt_len)))
        PRINTS_PER_EPOCH = 1   # 1 prints once per epoch, 2 prints twice per epoch
        print_every = max(1, steps_per_epoch // PRINTS_PER_EPOCH)
        LOG_EVERY_EPOCHS = 10  # log every 10 epochs; 1 logs every epoch
        global_step = 0

        opt = torch.optim.AdamW(model.parameters(), lr=lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            opt, T_max=epochs * steps_per_epoch, eta_min=1e-5
        )

        for epoch in range(1, epochs + 1):
            model.train()
            running = 0.0
            for it in range(steps_per_epoch):
                src_ids, tgt_ids = get_batch(batch_size, src_len, tgt_len)
                # Teacher forcing: decoder input is <BOS> + tgt[:-1]
                bos = torch.full((batch_size, 1), bos_id, dtype=torch.long, device=device)
                dec_in = torch.cat([bos, tgt_ids[:, :-1]], dim=1)

                logits, _ = model(src_ids, dec_in, src_pad_mask=None)
                loss = sequence_ce_loss(logits, tgt_ids, label_smoothing=0.1)

                opt.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                opt.step()
                scheduler.step()

                running += loss.item()
                global_step += 1

                if ((it + 1) % print_every == 0) or ((it + 1) == steps_per_epoch):
                    # Actual number of steps in the last window (handles the remainder at the end of an epoch)
                    window = (it + 1) % print_every
                    if window == 0:
                        window = print_every
                    total_step = (epoch - 1) * steps_per_epoch + (it + 1)
                    if (epoch % LOG_EVERY_EPOCHS == 0) or (epoch == 1) or (epoch == epochs):
                        print(f"epoch {epoch} step {total_step}: loss {running / window:.4f}")

                    # Per-epoch sample generation (sampling); disabled by default
                    # with torch.no_grad():
                    #     demo_src = "ROMEO:\n"
                    #     out = sample_generate(
                    #         model,
                    #         demo_src,
                    #         steps=200,
                    #         temperature=gen_temperature,
                    #         top_k=gen_top_k,
                    #         repetition_penalty=gen_repetition_penalty,
                    #     )
                    #     print("=" * 60)
                    #     print("SRC PROMPT:\n" + demo_src)
                    #     print("-" * 60)
                    #     print("SAMPLE GEN (sampling):\n" + out)
                    #     print("=" * 60)

                    running = 0.0

        save_weights(model)

        # Final sample
        src_prompt = "JULIET:\n"
        sample = sample_generate(
            model,
            src_prompt,
            steps=gen_steps,
            temperature=gen_temperature,
            top_k=gen_top_k,
            repetition_penalty=gen_repetition_penalty,
        )
        print("\n=== FINAL SAMPLE (sampling) ===")
        print("SRC PROMPT:\n" + src_prompt)
        print("GENERATED:\n" + sample)

    else:
        # ============================
        # Load mode: load weights and generate
        # ============================
        load_weights(model)
        demo_src = "ROMEO:\n"
        out = sample_generate(
            model,
            demo_src,
            steps=200,
            temperature=gen_temperature,
            top_k=gen_top_k,
            repetition_penalty=gen_repetition_penalty,
        )
        print("=" * 60)
        print("SRC PROMPT:\n" + demo_src)
        print("-" * 60)
        print("SAMPLE GEN (sampling):\n" + out)
        print("=" * 60)

if __name__ == "__main__":
    main()
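The program above generates by sampling (temperature / top-k / repetition penalty). For comparison, a greedy variant simply takes the argmax at each step. A minimal sketch that assumes the model and the encode/decode/bos_id/device definitions above; greedy_generate is not part of transformer_seq2seq.py:

# Greedy decoding variant for comparison with sample_generate() (a sketch, not part of the program).
@torch.no_grad()
def greedy_generate(model, src_prompt: str, steps: int = 300):
    model.eval()
    src_ids = encode(src_prompt).unsqueeze(0).to(device)
    dec_ids = torch.tensor([[bos_id]], dtype=torch.long, device=device)
    for _ in range(steps):
        logits, _ = model(src_ids, dec_ids)
        next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # argmax instead of sampling
        dec_ids = torch.cat([dec_ids, next_id], dim=1)
    return decode(dec_ids[0, 1:].tolist())

# Greedy output is deterministic and tends to repeat; the sampling settings used above
# (temperature=0.7, top_k=20, repetition_penalty=1.3) usually give more varied text.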
Execution result (mode = "train"):
Device: cpu
Vocab size: 66, data length: 1115394
epoch 1 step 272: loss 3.2069
epoch 10 step 2720: loss 2.7036
epoch 20 step 5440: loss 2.5672
epoch 30 step 8160: loss 2.4917
epoch 40 step 10880: loss 2.4434
epoch 50 step 13600: loss 2.3909
epoch 60 step 16320: loss 2.3582
epoch 70 step 19040: loss 2.3245
epoch 80 step 21760: loss 2.3034
epoch 90 step 24480: loss 2.2845
epoch 100 step 27200: loss 2.2777
[save] model weights saved to ./checkpoints/ed_standard_weights.pth
=== FINAL SAMPLE (sampling) ===
SRC PROMPT:
JULIET:
GENERATED:
esieh' i ot.
JlROPR:Tu,y acinwlrmntmc!mtgig oyJl,a!
hvc o--eprd o?
LD''slbkhldf'tmnbI?
adembrdad!twr,dsyppii!t,ls-fJwoao!Ig'b?fn!
hssi nwrvn o
et-hu:ru,Ili.JLu,dei!beda!hfi,pfnb
odmsym?oksc hhm,nahu nra wak puc!Ih-r!wi hwfJu-ugatlnrgd'q-w-h;Anl oeudd'
ucl-fs,NU-b nl!Ic:Y
Rtg.Hanfoyyf-u'uim:Itsag,yn
Execution result (mode = "load"):
Device: cpu
Vocab size: 66, data length: 1115394
[load] model weights loaded from ./checkpoints/ed_standard_weights.pth
============================================================
SRC PROMPT:
ROMEO:
------------------------------------------------------------
SAMPLE GEN (sampling):
etn i upan'op
hrsl,e'qae, nwrtymnsdbIndttbres
hcubesadigsymo i,rbtorlknwoig;cshdfaru?clt!cldM b
h a dk epcaiflhi!Te,hbl!wm,lc!gvpausuhwok
Rordspie,i uaar'e!B
Jpo!R!R' hwycti!wmsf am,pi!Ot-o sagv-olwm
============================================================
AI summary
During training, a portion of a huge body of text is taken as the input and the text that follows it as the output (the ground truth), and the parameters, which include a time element, are learned through repeated updates (a minimal sketch of this input/continuation setup follows below).
At inference time, the trained parameters are used to predict the text that follows whatever the user typed in.
The parameter count of GPT-5 is said to be around 300 billion (300B) or 635 billion (635B).
* Because text progresses as a time series (an ordering), that ordering information is referred to here as the "time element" for convenience.
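As a concrete illustration, the sample program builds each training pair exactly this way: a slice of text goes to the encoder and its continuation becomes the target, with teacher forcing on the decoder side. A minimal sketch that reuses encode() and bos_id from transformer_seq2seq.py (assumed to be in scope):

# Sketch of "input = a slice of text, target = its continuation", matching the
# teacher-forcing setup in transformer_seq2seq.py (encode() and bos_id assumed in scope).
import torch

chunk = "ROMEO: But soft, what light"        # a slice of the training text
src_text, tgt_text = chunk[:16], chunk[16:]  # first part -> encoder input, rest -> target

src_ids = encode(src_text).unsqueeze(0)            # encoder input, shape (1, 16)
tgt_ids = encode(tgt_text).unsqueeze(0)            # ground-truth continuation, shape (1, 11)
bos = torch.full((1, 1), bos_id, dtype=torch.long)
dec_in = torch.cat([bos, tgt_ids[:, :-1]], dim=1)  # decoder input: <BOS> + tgt[:-1]
# At decoder position t, the model is trained to predict tgt_ids[:, t] (next-token prediction).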
