4
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

ゼロから作るLLM Part2: Transformerの実装

Last updated at Posted at 2025-12-04

はじめに

Part1ではBPEトークナイザーを実装しました。
今回は、そのトークナイザーで処理したデータを学習・推論するための Transformerを実装します。

Transformerとは

Transformerは「Attention Is All You Need」で提案されたアーキテクチャであり、GPTなどのLLMのベースとなっています。以下の図はML Visualsから引用したTransformerの全体構造です。各コンポーネントの役割については次の章で見ていきます。

ML Visuals (1).png

なおTransformerの仕組みは以下のサイトでめちゃくちゃわかりやすく可視化されているので, まずこちらを見てみることをおすすめします。実装する際にもかなり参考にしました。

Transformerのコンポーネント実装

Transformerの各コンポーネントを実装ベースで解説していきます。コード全体はこちらに置いています。

1. Token Embedding

トークンIDから埋め込みベクトルを得るためのコンポーネントで、Input EmbeddingおよびOutput Embeddingで使用します。

class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size, model_dim, padding_idx=None):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, model_dim, padding_idx=padding_idx)
        self.model_dim = model_dim

    def forward(self, inputs):
        return self.embedding(inputs) * sqrt(self.model_dim)

nn.Embedding is Whatとなりますが、要するに以下のコードと同じです。

class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size, model_dim):
        super().__init__()
        self.embedding = nn.Parameter(torch.randn(size=(vocab_size, model_dim)))
        self.model_dim = model_dim

    def forward(self, inputs):
        return self.embedding[inputs] * sqrt(self.model_dim)

例えばvocab_size=5, model_dim=8 の場合, dogの埋め込みベクトルは以下の図のような手順で得られます(図中の数値は乱数で生成したものであり, 特に意味はないです)。

token_embedding.png

2. Positional Encoding

Transformerは再帰構造を持たないため、単語の位置情報をPositional Encodingを用いて埋め込みベクトルに加算します。

class PositionalEncoding(nn.Module):
    def __init__(self, model_dim, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        # Positional Encoding行列
        pe = torch.zeros(max_len, model_dim)
        # 位置インデックス (max_len, 1)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        denominator = torch.exp(
            torch.arange(0, model_dim, 2).float() * (-math.log(10000.0) / model_dim)
        )
        # 偶数インデックス
        pe[:, 0::2] = torch.sin(position * denominator)
        # 奇数インデックス
        pe[:, 1::2] = torch.cos(position * denominator)
        # バッチ次元を追加
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, inputs):
        # 入力: (batch_size, seq_len, model_dim)
        seq_len = inputs.size(1)
        # Positional Encodingを加算
        inputs = inputs + self.pe[:, :seq_len, :]
        return self.dropout(inputs)

位置pos、次元iの位置エンコーディングは以下の式で表されます:

$$PE_{(pos,2i)} = \sin\left(\frac{pos}{10000^{2i/d_{\mathrm{model}}}}\right)$$

$$PE_{(pos,2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{\mathrm{model}}}}\right)$$

${10000^{2i/d_{\mathrm{model}}}}$の部分は指数計算が不安定なので, 対数を使って以下のように変換しています:

$$10000^{2i/d_{\text{model}}} = \exp(\log(10000^{2i/d_{\text{model}}})) = \exp\left(\frac{2i}{d_{\text{model}}} \cdot \log(10000)\right)$$

最終的に得られる位置エンコーディングpeを可視化したものが以下の図です.

positional_encoding.png

3. Multi-Head Attention

Attentionは「どの単語に注目すべきか」を計算する仕組みです。Transformerでは、複数の「ヘッド」に分割して異なる観点からトークン間の関係を学習するMulti-Head Attentionを使用します。

Attentionの計算式

各ヘッドで以下の計算を行います:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V$$

実装

class MultiheadAttention(nn.Module):
    def __init__(self, model_dim, head=8, dropout=0.1):
        super().__init__()
        if model_dim % head != 0:
            raise ValueError("model_dim must be divisible by head")
        # Q(query), K(key), V(value)の重みの次元
        self.w_qkv_dim = model_dim // head
        self.head = head
        # Q, K, V: (model_dim, self.w_qkv_dim * head = model_dim)の線形変換
        self.q_proj = nn.Linear(model_dim, self.w_qkv_dim * head, bias=True)
        self.k_proj = nn.Linear(model_dim, self.w_qkv_dim * head, bias=True)
        self.v_proj = nn.Linear(model_dim, self.w_qkv_dim * head, bias=True)
        # 出力: (self.w_qkv_dim * head = model_dim, model_dim)の線形変換
        self.out_proj = nn.Linear(self.w_qkv_dim * head, model_dim, bias=True)
        self.dropout = nn.Dropout(p=dropout)
        self.last_attention_weights = None

    def forward(self, inputs, encoder_out=None, mask=None):
        # input: (batch_size, seq_len, model_dim)
        batch_size, seq_len, _ = inputs.shape

        # (batch_size, seq_len, model_dim)
        # → (batch_size, seq_len, model_dim)
        query = self.q_proj(inputs)
        key = self.k_proj(encoder_out if encoder_out is not None else inputs)
        value = self.v_proj(encoder_out if encoder_out is not None else inputs)

        # 全ヘッドのAttention計算を行列演算で行うために変形·並び替え
        # (batch_size, seq_len, model_dim)
        # → (batch_size, seq_len, head, w_qkv_dim)
        # → (batch_size, head, seq_len, w_qkv_dim)
        query = query.view(batch_size, seq_len, self.head, self.w_qkv_dim).transpose(
            1, 2
        )
        key = key.view(batch_size, -1, self.head, self.w_qkv_dim).transpose(1, 2)
        value = value.view(batch_size, -1, self.head, self.w_qkv_dim).transpose(1, 2)

        # Q x K^T: (batch_size, head, seq_len, seq_len)
        # key.transpose(-2, -1): (batch_size, head, w_qkv_dim, seq_len)
        scores = torch.matmul(query, key.transpose(-2, -1))

        # Kの次元の平方根で割る
        scores /= sqrt(self.w_qkv_dim)

        # マスク処理を適用(次のセクションで説明)
        if mask is not None:
            if mask.dim() == 2:
                if encoder_out is not None:  # Cross-attention
                    mask = mask.unsqueeze(1).unsqueeze(2)
                else:  # Self-attention
                    mask = mask.unsqueeze(0).unsqueeze(0)
            elif mask.dim() == 3:
                mask = mask.unsqueeze(1)
            scores = scores.masked_fill(~mask, -1e9)

        # softmax: (batch_size, head, seq_len, seq_len)
        attention_weights = F.softmax(scores, dim=-1)

        self.last_attention_weights = attention_weights

        # dropout
        attention_weights = self.dropout(attention_weights)

        # z: (batch_size, head, seq_len, w_qkv_dim)
        z = torch.matmul(attention_weights, value)

        # (batch_size, head, seq_len, w_qkv_dim)
        # → (batch_size, seq_len, head, w_qkv_dim)
        # cf: https://qiita.com/kenta1984/items/d68b72214ce92beebbe2
        z = z.transpose(1, 2).contiguous()

        # (batch_size, seq_len, head, w_qkv_dim)
        # → (batch_size, seq_len, head * w_qkv_dim)
        z = z.view(batch_size, seq_len, self.head * self.w_qkv_dim)

        # (batch_size, seq_len, model_dim)
        output = self.out_proj(z)

        # dropout
        output = self.dropout(output)
        return output

Self-Attention と Cross-Attention

Multi-Head Attentionは、Q, K, Vの入力元によって2つの使い方があります:

  • Self-Attention: Q, K, V をすべて同じ入力から生成(Encoder·Decoderで使用)
  • Cross-Attention: Q は Decoder、K と V は Encoder から生成(Decoderで使用)

実装では encoder_out パラメータの有無で切り替えています:

# Self-Attention: encoder_out=None
key = self.k_proj(inputs)
value = self.v_proj(inputs)

# Cross-Attention: encoder_out が指定されている
key = self.k_proj(encoder_out)
value = self.v_proj(encoder_out)

4. マスク処理

Transformerは全単語を同時に処理するため、「見てはいけない情報」を明示的にマスクする必要があります。マスクは主に2種類あります。

Padding Mask

バッチ学習では、短い文に <PAD> を追加して長さを揃えます。このパディング部分をAttentionの計算から除外する必要があります。

padding_mask = (src != PAD_IDX)  # True: 有効なトークン, False: <PAD>

Causal Mask

Decoderは出力文を生成する際、未来の単語(まだ生成していない単語)を見てはならないので, マスクしてあげる必要があります。

例: I like sushi を生成する場合の各位置で見てよい/見てはいけない範囲

位置 入力 見てよい 見てはいけない
0 <BOS> <BOS> I, like, sushi
1 I <BOS>, I like, sushi
2 like <BOS>, I, like sushi

Causal Maskの構造(下三角行列):

       <BOS>  I  like  sushi
<BOS>  [ 1    0   0    0 ]
I      [ 1    1   0    0 ]
like   [ 1    1   1    0 ]
sushi  [ 1    1   1    1 ]

1: 見てよい, 0: 見てはいけない

実装では torch.tril で下三角行列を生成しています:

def create_causal_mask(seq_len, device=None):
    return torch.tril(torch.ones(seq_len, seq_len, device=device)).bool()

5. Feed Forward Network

Attentionの後に適用される全結合層です。

class FeedForward(nn.Module):
    def __init__(self, model_dim, dropout=0.1):
        super().__init__()
        self.linear_1 = nn.Linear(model_dim, model_dim * 4)
        self.activation_func = nn.ReLU()
        self.linear_2 = nn.Linear(model_dim * 4, model_dim)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, inputs):
        middle = self.linear_1(inputs)
        relu = self.activation_func(middle)
        output = self.linear_2(relu)
        # dropout
        output = self.dropout(output)
        return output

中間層のサイズをモデル次元の4倍にしてからもとに戻しています。

6. Encoder Layer

Self-Attention、Feed Forward、Layer Normalization、Residual Connection(残差結合, Transformerの構成図におけるブロックをスキップしている部分)を組み合わせます。

class EncoderLayer(nn.Module):
    def __init__(self, model_dim):
        super().__init__()
        self.attention = MultiheadAttention(model_dim)
        self.normalizer_1 = nn.LayerNorm(model_dim)
        self.feed_forward = FeedForward(model_dim)
        self.normalizer_2 = nn.LayerNorm(model_dim)

    def forward(self, inputs, mask=None):
        # 1. Self-Attention + Residual Connection + Layer Normalization
        attention_out = self.attention(inputs, mask=mask)
        normalized_1 = self.normalizer_1(inputs + attention_out)

        # 2. Feed Forward + Residual Connection + Layer Normalization
        feed_forward_out = self.feed_forward(normalized_1)
        normalized_2 = self.normalizer_2(normalized_1 + feed_forward_out)
        return normalized_2

Encoder

Encoderは複数のEncoder Layerを重ねて構成します。

class Encoder(nn.Module):
    def __init__(self, model_dim, encoder_num):
        super().__init__()
        self.encoders = nn.ModuleList(
            [EncoderLayer(model_dim) for _ in range(encoder_num)]
        )

    def forward(self, inputs, mask=None):
        for encoder_layer in self.encoders:
            inputs = encoder_layer(inputs, mask)
        return inputs

7. Decoder Layer

Decoder LayerはEncoder LayerにCross-Attentionを追加した構造です。

class DecoderLayer(nn.Module):
    def __init__(self, model_dim):
        super().__init__()
        self.masked_attention = MultiheadAttention(model_dim)
        self.normalizer_1 = nn.LayerNorm(model_dim)
        self.attention = MultiheadAttention(model_dim)
        self.normalizer_2 = nn.LayerNorm(model_dim)
        self.feed_forward = FeedForward(model_dim)
        self.normalizer_3 = nn.LayerNorm(model_dim)

    def forward(self, inputs, encoder_out, tgt_mask=None, src_mask=None):
        # 1. Masked Self-Attention(未来の単語を見ない)
        masked_attention_out = self.masked_attention(inputs, mask=tgt_mask)
        normalized_1 = self.normalizer_1(inputs + masked_attention_out)

        # 2. Cross-Attention(Encoderの出力を参照)
        attention_out = self.attention(normalized_1, encoder_out, mask=src_mask)
        normalized_2 = self.normalizer_2(normalized_1 + attention_out)

        # 3. Feed Forward
        feed_forward_out = self.feed_forward(normalized_2)
        normalized_3 = self.normalizer_3(normalized_2 + feed_forward_out)
        return normalized_3

Decoder

Decoderも複数のDecoder Layerを重ねて構成されます。

class Decoder(nn.Module):
    def __init__(self, model_dim, decoder_num):
        super().__init__()
        self.decoders = nn.ModuleList(
            [DecoderLayer(model_dim) for _ in range(decoder_num)]
        )

    def forward(self, inputs, encoder_out, tgt_mask=None, src_mask=None):
        for decoder_layer in self.decoders:
            inputs = decoder_layer(inputs, encoder_out, tgt_mask, src_mask)
        return inputs

Transformerの実装

これらを組み合わせて Transformer クラスを作ります。

class Transformer(nn.Module):
    def __init__(
        self,
        src_vocab_size,
        tgt_vocab_size,
        model_dim,
        encoder_num,
        decoder_num,
        padding_idx=None,
    ):
        super().__init__()
        self.src_embedding = TokenEmbedding(src_vocab_size, model_dim, padding_idx)
        self.tgt_embedding = TokenEmbedding(tgt_vocab_size, model_dim, padding_idx)
        self.positional_encoding = PositionalEncoding(model_dim)
        self.encoder = Encoder(model_dim, encoder_num)
        self.decoder = Decoder(model_dim, decoder_num)
        self.decoder_proj = nn.Linear(model_dim, tgt_vocab_size)

    def forward(
        self,
        source_tokens,
        target_tokens,
        encoder_src_mask=None,
        decoder_src_mask=None,
        tgt_mask=None,
    ):
        # (batch_size, source_len, model_dim)
        source_embed = self.src_embedding(source_tokens)
        source_embed = self.positional_encoding(source_embed)
        # (batch_size, target, model_dim)
        target_embed = self.tgt_embedding(target_tokens)
        target_embed = self.positional_encoding(target_embed)

        encoder_out = self.encoder(source_embed, encoder_src_mask)
        decoder_out = self.decoder(
            target_embed, encoder_out, tgt_mask, decoder_src_mask
        )
        output = self.decoder_proj(decoder_out)
        return output

まとめ

これでTransformerモデルの実装が完成しました。
Part3 ではこのモデルを使って英日翻訳タスクを解いていきます。

参考

ゼロから作るLLM - 目次

4
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?