
AI要素⑧ トランスフォーマー

Posted at 2025-09-28

This article describes the component technologies of AI.
 

References

G検定 official textbook

Multi-head Attention [a building block of the Transformer] (CVMLエキスパートガイド)

Understanding

A model introduced in 2017 (the "Attention is All You Need" paper).

Sample program

transformer_seq2seq.py

Q/K/V (query, key, value) are produced by linear projections → split into heads → Scaled Dot-Product Attention (see the formula sketch right after this list)
EncoderLayer: Self-Attention + FFN
DecoderLayer: Causal Self-Attention + Cross-Attention (Source-Target Attention) + FFN
Positional encoding (sinusoidal)
The padding mask and the decoder's causal mask are applied in combination
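For reference, a sketch of the computation named above, in the notation of the original paper (here $d_k$ is the per-head dimension, i.e. d_model / n_heads):

$$\mathrm{Attention}(Q,K,V)=\mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V$$

$$\mathrm{MultiHead}(Q,K,V)=\mathrm{Concat}(\mathrm{head}_1,\ldots,\mathrm{head}_h)\,W^{O},\qquad \mathrm{head}_i=\mathrm{Attention}(QW_i^{Q},\,KW_i^{K},\,VW_i^{V})$$

The padding and causal masks are added to $QK^{\top}/\sqrt{d_k}$ as $-\infty$ entries before the softmax, which is how the sample program below applies them.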

Because training is character-level (char), the output ends up as strings of character fragments that only look vaguely plausible.
Switching from characters to subword tokens such as BPE/SentencePiece makes it easier to learn word- and phrase-level structure, and the "English-likeness" of the output improves substantially. However, moving toward word-level units grows the vocabulary to tens of thousands of entries, complicates the implementation, and increases the compute cost.
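A minimal sketch of the subword route, assuming the sentencepiece package is installed (pip install sentencepiece); the vocab_size and file names below are illustrative choices, not values from this article:

import sentencepiece as spm

# Train a small BPE vocabulary on the same corpus (illustrative settings).
spm.SentencePieceTrainer.train(
    input="./data/tinyshakespeare.txt",
    model_prefix="tinyshakespeare_bpe",
    vocab_size=4000,
    model_type="bpe",
)

# Encode / decode with the trained tokenizer.
sp = spm.SentencePieceProcessor(model_file="tinyshakespeare_bpe.model")
ids = sp.encode("JULIET:\nO Romeo, Romeo! wherefore art thou Romeo?", out_type=int)
print(ids)             # subword token ids (a few thousand entries instead of 66 chars)
print(sp.decode(ids))  # decodes back to the original text

The rest of the sample program would then work on these ids instead of character ids, with vocab_size taken from sp.get_piece_size().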

Parameter tuning

Lightweight settings that run even on a CPU
d_model = 128
n_heads = 4
d_ff = 256
num_encoder_layers = 3
num_decoder_layers = 3
dropout_rate = 0.1
batch_size = 32
epochs = 8

One step stronger (CPU / light GPU)
d_model = 384
n_heads = 8
d_ff = 4 * d_model # 1536
num_encoder_layers = 4
num_decoder_layers = 4
dropout_rate = 0.2
batch_size = 16    # 16 with 16 GB of PC RAM; 32 if d_model=256 and enc/dec=3/3
epochs = 100

Further strengthened (GPU assumed)
d_model = 512
n_heads = 8
d_ff = 4 * d_model # 2048
num_encoder_layers = 6
num_decoder_layers = 6
dropout_rate = 0.3
batch_size = 32
epochs = 120
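To sanity-check how large each configuration actually is, you can count trainable parameters after building the model. This is a sketch that assumes the StandardEncoderDecoder class from the sample program below and the vocabulary size it reports (66); the exact total varies slightly with those choices:

# Build the "one step stronger" configuration and count trainable parameters.
model = StandardEncoderDecoder(
    vocab_size=66,                      # char vocabulary size reported by the sample program
    d_model=384, n_heads=8, d_ff=1536,
    num_encoder_layers=4, num_decoder_layers=4, dropout=0.2,
)
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"{n_params / 1e6:.1f}M parameters")  # should land around the ~18M noted in the code comments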

Sample program

Additional library
pip install torch
080_transformer_seq2seq.py
# -*- coding: utf-8 -*-
"""
Standard Transformer (Encoder-Decoder) trained on tiny-shakespeare
- Character-level (char-level)
- Each decoder layer: Self-Attn -> Cross-Attn -> FFN (standard layout)
- Sampling-based generation (temperature / top-k / repetition_penalty)
"""

import os, math, urllib.request, pathlib, random
import torch
import torch.nn as nn
import torch.nn.functional as F

# =========================
# Hyperparameters
# =========================
device = "cuda" if torch.cuda.is_available() else "cpu"
seed = 42
random.seed(seed); torch.manual_seed(seed)

data_dir = "./data"
tiny_url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
tiny_path = os.path.join(data_dir, "tinyshakespeare.txt")

# Model & training (the "one step stronger" CPU / light-GPU settings, with longer training; roughly ~18M parameters)
d_model = 384
n_heads = 6
d_ff = 4 * d_model # 1536
max_len = 512
src_len = 128
tgt_len = 128
batch_size = 16
epochs = 100              # train longer
lr = 3e-4
bos_token_str = "<BOS>"

# Number of layers (configurable)
num_encoder_layers = 4
num_decoder_layers = 4
dropout_rate = 0.2

# Defaults for generation (sampling)
gen_steps = 300
gen_temperature = 0.7
gen_top_k = 20
gen_repetition_penalty = 1.3

# =========================
# Data download
# =========================
def ensure_tiny_shakespeare():
    pathlib.Path(data_dir).mkdir(parents=True, exist_ok=True)
    if not os.path.exists(tiny_path):
        print("Downloading tiny-shakespeare...")
        urllib.request.urlretrieve(tiny_url, tiny_path)
        print("Saved:", tiny_path)

ensure_tiny_shakespeare()

with open(tiny_path, "r", encoding="utf-8") as f:
    text = f.read()

# Character vocabulary (+ BOS)
chars = sorted(list(set(text)))
if bos_token_str not in chars:
    chars.append(bos_token_str)

stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for ch, i in stoi.items()}
vocab_size = len(chars)
bos_id = stoi[bos_token_str]

def encode(s: str):
    return torch.tensor([stoi[c] for c in s], dtype=torch.long)

def decode(ids):
    return "".join(itos[int(i)] for i in ids)

# =========================
# Mini-batch loader
# =========================
def get_batch(batch_size, src_len, tgt_len):
    B = batch_size
    total_len = src_len + tgt_len
    xs = torch.zeros((B, src_len), dtype=torch.long)
    ys = torch.zeros((B, tgt_len), dtype=torch.long)
    for b in range(B):
        i = random.randint(0, len(text) - total_len - 1)
        chunk = text[i : i + total_len]
        xs[b] = encode(chunk[:src_len])
        ys[b] = encode(chunk[src_len:])
    return xs.to(device), ys.to(device)

# =========================
# Positional encoding
# =========================
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 4096):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        T = x.size(1)
        return x + self.pe[:T, :]

# =========================
# Multi-Head Attention
# =========================
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int, bias: bool = True):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.W_q = nn.Linear(d_model, d_model, bias=bias)
        self.W_k = nn.Linear(d_model, d_model, bias=bias)
        self.W_v = nn.Linear(d_model, d_model, bias=bias)
        self.W_o = nn.Linear(d_model, d_model, bias=bias)

    def _split(self, x):
        B, T, D = x.shape
        return x.view(B, T, self.n_heads, self.d_head).transpose(1, 2)  # (B,H,T,Dh)

    def _merge(self, x):
        B, H, T, Dh = x.shape
        return x.transpose(1, 2).contiguous().view(B, T, H * Dh)

    def forward(self, q_inp, k_inp, v_inp, attn_mask=None, return_qkv=False):
        Q = self._split(self.W_q(q_inp))
        K = self._split(self.W_k(k_inp))
        V = self._split(self.W_v(v_inp))
        scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.d_head)  # (B,H,Tq,Tk)
        if attn_mask is not None:
            scores = scores + attn_mask
        attn = F.softmax(scores, dim=-1)
        ctx = attn @ V
        out = self.W_o(self._merge(ctx))
        if return_qkv:
            return out, Q, K, V
        return out

# =========================
# Encoder layer (Self-Attn → FFN)
# =========================
class EncoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff=256, dropout=0.1):
        super().__init__()
        self.self_mha = MultiHeadAttention(d_model, n_heads)
        self.ln1 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
        self.ln2 = nn.LayerNorm(d_model)
        self.do = nn.Dropout(dropout)

    def forward(self, x, attn_mask=None, return_qkvs=False):
        attn_out, Q, K, V = self.self_mha(x, x, x, attn_mask, return_qkv=True)
        x = self.ln1(x + self.do(attn_out))
        ff_out = self.ff(x)
        x = self.ln2(x + self.do(ff_out))
        if return_qkvs:
            return x, (Q, K, V)
        return x

# =========================
# Decoder layer (Self-Attn → Cross-Attn → FFN)
# =========================
class DecoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff=256, dropout=0.1):
        super().__init__()
        self.self_mha = MultiHeadAttention(d_model, n_heads)
        self.ln1 = nn.LayerNorm(d_model)

        self.cross_mha = MultiHeadAttention(d_model, n_heads)
        self.ln2 = nn.LayerNorm(d_model)

        self.ff = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
        self.ln3 = nn.LayerNorm(d_model)
        self.do = nn.Dropout(dropout)

    @staticmethod
    def causal_mask(T_q, T_k, device):
        mask = torch.full((T_q, T_k), float("-inf"), device=device)
        mask = torch.triu(mask, diagonal=1)  # keep -inf strictly above the diagonal (hide future positions)
        return mask.unsqueeze(0).unsqueeze(0)  # (1,1,Tq,Tk)

    def forward(self, x, memory, device, memory_pad_mask=None, return_self_q=False):
        # 1) Masked Self-Attn
        B, T_tgt, _ = x.shape
        self_mask = self.causal_mask(T_tgt, T_tgt, device)
        self_out, Q_self, K_self, V_self = self.self_mha(x, x, x, self_mask, return_qkv=True)
        x = self.ln1(x + self.do(self_out))

        # 2) Cross-Attn
        cross_mask = None
        if memory_pad_mask is not None:
            m = memory_pad_mask.unsqueeze(1).unsqueeze(2)  # (B,1,1,T_src)
            cross_mask = m * torch.finfo(x.dtype).min

        cross_out, Q_cross, K_cross, V_cross = self.cross_mha(x, memory, memory, cross_mask, return_qkv=True)
        x = self.ln2(x + self.do(cross_out))

        # 3) FFN
        ff_out = self.ff(x)
        x = self.ln3(x + self.do(ff_out))

        if return_self_q:
            return x, Q_self
        return x

# =========================
# Encoder / decoder stacks
# =========================
class EncoderStack(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, num_layers, max_len=4096, dropout=0.1):
        super().__init__()
        self.pe = PositionalEncoding(d_model, max_len)
        self.layers = nn.ModuleList([EncoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(num_layers)])

    def forward(self, src_embed, src_pad_mask=None, return_last_qkv=True):
        x = self.pe(src_embed)
        attn_mask = None
        if src_pad_mask is not None:
            mask = src_pad_mask.unsqueeze(1).unsqueeze(2)  # (B,1,1,T_src)
            attn_mask = mask * torch.finfo(x.dtype).min

        last_Q = last_K = last_V = None
        for i, layer in enumerate(self.layers):
            if return_last_qkv and (i == len(self.layers) - 1):
                x, (last_Q, last_K, last_V) = layer(x, attn_mask, return_qkvs=True)
            else:
                x = layer(x, attn_mask, return_qkvs=False)
        return x, last_K, last_V  # encoder_key, encoder_value (from the last layer)

class DecoderStack(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, num_layers, max_len=4096, dropout=0.1):
        super().__init__()
        self.pe = PositionalEncoding(d_model, max_len)
        self.layers = nn.ModuleList([DecoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(num_layers)])

    def forward(self, tgt_embed, memory, device, memory_pad_mask=None, return_last_self_q=True):
        x = self.pe(tgt_embed)
        last_Q_self = None
        for i, layer in enumerate(self.layers):
            if return_last_self_q and (i == len(self.layers) - 1):
                x, last_Q_self = layer(x, memory, device, memory_pad_mask, return_self_q=True)
            else:
                x = layer(x, memory, device, memory_pad_mask, return_self_q=False)
        return x, last_Q_self  # decoder_query (Q of the last layer's self-attention)

# =========================
# Standard Transformer (Encoder-Decoder)
# =========================
class StandardEncoderDecoder(nn.Module):
    def __init__(
        self,
        vocab_size,
        d_model=128,
        n_heads=4,
        d_ff=256,
        max_len=4096,
        num_encoder_layers=3,
        num_decoder_layers=3,
        dropout=0.1,
    ):
        super().__init__()
        self.d_model = d_model
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.encoder = EncoderStack(d_model, n_heads, d_ff, num_encoder_layers, max_len=max_len, dropout=dropout)
        self.decoder = DecoderStack(d_model, n_heads, d_ff, num_decoder_layers, max_len=max_len, dropout=dropout)
        self.out_proj = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, src_ids, tgt_ids, src_pad_mask=None):
        src_embed = self.tok_emb(src_ids) * math.sqrt(self.d_model)
        tgt_embed = self.tok_emb(tgt_ids) * math.sqrt(self.d_model)

        memory, encoder_key, encoder_value = self.encoder(src_embed, src_pad_mask, return_last_qkv=True)
        dec_out, decoder_query = self.decoder(tgt_embed, memory, src_ids.device,
                                              memory_pad_mask=src_pad_mask,
                                              return_last_self_q=True)

        logits = self.out_proj(dec_out)  # (B,T_tgt,V)
        aux = {
            "encoder_key": encoder_key,
            "encoder_value": encoder_value,
            "decoder_query": decoder_query,
            "memory": memory,
            "dec_out": dec_out,
        }
        return logits, aux

# =========================
# Training utilities
# =========================
def sequence_ce_loss(logits, target_ids, label_smoothing: float = 0.0):
    """
    logits: (B,T,V), target_ids: (B,T)
    Next-token prediction. The decoder input fed to the model is already
    shifted (<BOS> + target[:-1]), so logits[:, i] should predict target[:, i];
    shifting again here would make the model learn to skip one token ahead.
    """
    V = logits.size(-1)
    return F.cross_entropy(logits.reshape(-1, V), target_ids.reshape(-1),
                           label_smoothing=label_smoothing)

# =========================
# Sampling generation (temperature / top-k / repetition penalty)
# =========================
@torch.no_grad()
def sample_generate(model,
                    src_prompt: str,
                    steps: int = 300,
                    temperature: float = 0.9,
                    top_k: int | None = 40,
                    repetition_penalty: float = 1.1):
    """
    - generates by sampling rather than greedy decoding
    - repetition_penalty slightly suppresses tokens that have already been emitted
    """
    model.eval()
    src_ids = encode(src_prompt).unsqueeze(0).to(device)  # (1, Ts)
    dec_ids = torch.tensor([[bos_id]], dtype=torch.long, device=device)  # (1,1)
    src_pad_mask = None

    for _ in range(steps):
        logits, _ = model(src_ids, dec_ids, src_pad_mask=src_pad_mask)  # (1, Td, V)
        next_logits = logits[:, -1, :]

        # temperature scaling
        next_logits = next_logits / max(temperature, 1e-6)

        # repetition penalty (simple): weaken the logits of tokens already emitted
        # (divide positive logits, multiply negative ones, so both are suppressed)
        if repetition_penalty and repetition_penalty > 1.0:
            for tok in set(dec_ids[0].tolist()):
                if next_logits[0, tok] > 0:
                    next_logits[0, tok] /= repetition_penalty
                else:
                    next_logits[0, tok] *= repetition_penalty

        # top-k filtering
        if top_k is not None and top_k > 0:
            k = min(top_k, next_logits.size(-1))
            v, _ = torch.topk(next_logits, k=k)
            thresh = v[:, -1].unsqueeze(-1)
            next_logits = torch.where(next_logits < thresh, torch.full_like(next_logits, float("-inf")), next_logits)

        probs = torch.softmax(next_logits, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)  # sample the next token
        dec_ids = torch.cat([dec_ids, next_id], dim=1)

    gen_ids = dec_ids[0, 1:].tolist()
    return decode(gen_ids)

# ======================================
# Save / load (weights only)
# ======================================
CHECKPOINT_DIR = "./checkpoints"
CHECKPOINT_PATH = os.path.join(CHECKPOINT_DIR, "ed_standard_weights.pth")

def save_weights(model):
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    torch.save(model.state_dict(), CHECKPOINT_PATH)
    print(f"[save] model weights saved to {CHECKPOINT_PATH}")

def load_weights(model):
    state = torch.load(CHECKPOINT_PATH, map_location=device)
    model.load_state_dict(state)
    model.eval()
    print(f"[load] model weights loaded from {CHECKPOINT_PATH}")

# =========================
# Main
# =========================
def main():
    print(f"Device: {device}")
    print(f"Vocab size: {vocab_size}, data length: {len(text)}")

    model = StandardEncoderDecoder(
        vocab_size,
        d_model=d_model,
        n_heads=n_heads,
        d_ff=d_ff,
        max_len=max_len,
        num_encoder_layers=num_encoder_layers,
        num_decoder_layers=num_decoder_layers,
        dropout=dropout_rate,
    ).to(device)

    mode = "train"   # "train" が学習して保存。"load" にすると保存済みモデルをロード

    if mode == "train":
        # ============================
        # Training mode: save once at the end
        # ============================
        steps_per_epoch = max(1, len(text) // (batch_size * (src_len + tgt_len)))
        PRINTS_PER_EPOCH = 1  # 1 prints once per epoch, 2 prints twice per epoch
        print_every = max(1, steps_per_epoch // PRINTS_PER_EPOCH)

        LOG_EVERY_EPOCHS = 10  # print once every 10 epochs; 1 prints every epoch

        global_step = 0

        opt = torch.optim.AdamW(model.parameters(), lr=lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            opt, T_max=epochs * steps_per_epoch, eta_min=1e-5
        )

        for epoch in range(1, epochs + 1):
            model.train()
            running = 0.0
            for it in range(steps_per_epoch):
                src_ids, tgt_ids = get_batch(batch_size, src_len, tgt_len)

                # Teacher forcing: decoder input is <BOS> + tgt[:-1]
                bos = torch.full((batch_size, 1), bos_id, dtype=torch.long, device=device)
                dec_in = torch.cat([bos, tgt_ids[:, :-1]], dim=1)

                logits, _ = model(src_ids, dec_in, src_pad_mask=None)
                loss = sequence_ce_loss(logits, tgt_ids, label_smoothing=0.1)

                opt.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                opt.step()
                scheduler.step()

                running += loss.item()
                global_step += 1
                if ((it + 1) % print_every == 0) or ((it + 1) == steps_per_epoch):
                    # Actual number of steps in the most recent window (handles the remainder at the end of an epoch)
                    window = (it + 1) % print_every
                    if window == 0:
                        window = print_every

                    total_step = (epoch - 1) * steps_per_epoch + (it + 1)
                    if (epoch % LOG_EVERY_EPOCHS == 0) or (epoch == 1) or (epoch == epochs):
                        print(f"epoch {epoch} step {total_step}: loss {running / window:.4f}")

                        # Sample generation at logged epochs (sampling)
                        #with torch.no_grad():
                        #    demo_src = "ROMEO:\n"
                        #    out = sample_generate(
                        #        model,
                        #        demo_src,
                        #        steps=200,
                        #        temperature=gen_temperature,
                        #        top_k=gen_top_k,
                        #        repetition_penalty=gen_repetition_penalty,
                        #    )
                        #    print("=" * 60)
                        #    print("SRC PROMPT:\n" + demo_src)
                        #    print("-" * 60)
                        #    print("SAMPLE GEN (sampling):\n" + out)
                        #    print("=" * 60)

                    running = 0.0

        save_weights(model)

        # Final sample
        src_prompt = "JULIET:\n"
        sample = sample_generate(
            model,
            src_prompt,
            steps=gen_steps,
            temperature=gen_temperature,
            top_k=gen_top_k,
            repetition_penalty=gen_repetition_penalty,
        )
        print("\n=== FINAL SAMPLE (sampling) ===")
        print("SRC PROMPT:\n" + src_prompt)
        print("GENERATED:\n" + sample)

    else:
        # ============================
        # Load mode: load the saved weights and generate
        # ============================
        load_weights(model)

        demo_src = "ROMEO:\n"
        out = sample_generate(
            model,
            demo_src,
            steps=200,
            temperature=gen_temperature,
            top_k=gen_top_k,
            repetition_penalty=gen_repetition_penalty,
        )
        print("=" * 60)
        print("SRC PROMPT:\n" + demo_src)
        print("-" * 60)
        print("SAMPLE GEN (sampling):\n" + out)
        print("=" * 60)

if __name__ == "__main__":
    main()

Result: training mode
Device: cpu
Vocab size: 66, data length: 1115394
epoch 1 step 272: loss 3.2069
epoch 10 step 2720: loss 2.7036
epoch 20 step 5440: loss 2.5672
epoch 30 step 8160: loss 2.4917
epoch 40 step 10880: loss 2.4434
epoch 50 step 13600: loss 2.3909
epoch 60 step 16320: loss 2.3582
epoch 70 step 19040: loss 2.3245
epoch 80 step 21760: loss 2.3034
epoch 90 step 24480: loss 2.2845
epoch 100 step 27200: loss 2.2777
[save] model weights saved to ./checkpoints/ed_standard_weights.pth

=== FINAL SAMPLE (sampling) ===
SRC PROMPT:
JULIET:

GENERATED:
 esieh' i ot.
JlROPR:Tu,y acinwlrmntmc!mtgig oyJl,a!
hvc o--eprd o?
LD''slbkhldf'tmnbI?
adembrdad!twr,dsyppii!t,ls-fJwoao!Ig'b?fn!
hssi nwrvn o
et-hu:ru,Ili.JLu,dei!beda!hfi,pfnb
odmsym?oksc hhm,nahu nra wak puc!Ih-r!wi hwfJu-ugatlnrgd'q-w-h;Anl oeudd'
ucl-fs,NU-b nl!Ic:Y
Rtg.Hanfoyyf-u'uim:Itsag,yn

Result: load mode
Device: cpu
Vocab size: 66, data length: 1115394
[load] model weights loaded from ./checkpoints/ed_standard_weights.pth
============================================================
SRC PROMPT:
ROMEO:

------------------------------------------------------------
SAMPLE GEN (sampling):
 etn i upan'op
hrsl,e'qae, nwrtymnsdbIndttbres
hcubesadigsymo i,rbtorlknwoig;cshdfaru?clt!cldM b
h a dk epcaiflhi!Te,hbl!wm,lc!gvpausuhwok
Rordspie,i uaar'e!B
Jpo!R!R' hwycti!wmsf am,pi!Ot-o sagv-olwm
============================================================

AI summary

During training, a portion of a huge body of text is taken as the input and the text that follows it as the output (the ground truth), and the parameters, which include a time element, are learned through repeated iterations.

At inference time, the trained parameters are used to predict the text that follows whatever the user has entered.
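As a tiny illustration of such an input/continuation pair (the sentence and split lengths are made up for this example):

# One training example: a slice of text as input, the text that follows it as the target.
text = "To be, or not to be, that is the question"
src_len, tgt_len = 19, 10
src = text[:src_len]                   # "To be, or not to be"  -> fed to the encoder
tgt = text[src_len:src_len + tgt_len]  # ", that is "           -> continuation used as the training target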

The number of parameters in GPT-5 is said to be around 300 billion (300B) or 635 billion (635B).

* Because text proceeds as a time series (in order), that ordering information is referred to here as the "time element" for convenience.

 
