2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

ゼロから作るLLM Part3: Transformerで英日翻訳

Last updated at Posted at 2025-12-21

はじめに

Part2ではTransformerモデルを実装しました。
今回は実装したTransformerを使って英日翻訳タスクを解いていきます。

今回やること

  • データセット準備とトークナイザー訓練
  • 訓練パイプラインの構築
  • モデルの訓練
  • 推論(Beam Search、Greedy Decoding)

PyTorch Lightning について

本プロジェクトでは訓練ループの実装に PyTorch Lightning を使用しています。

PyTorch Lightning とは

PyTorchで書くのがやや億劫な訓練ループの抽象化をいい感じに行ってくれるライブラリです。ここでは詳しい説明は割愛しますが、以下のページを見ると雰囲気は掴めると思います。最初は訓練ループを手書きしてましたが、途中でこのライブラリの存在を知ったのでClaude Codeに書き換えてもらいました。若干変更コストは高いのでできれば最初から導入したかったです。

プロジェクト構成

scripts/
  config.py          # 設定管理
  prepare.py         # トークナイザー訓練
  train.py           # モデル訓練
  eval.py            # モデル評価
  inference.py       # 推論
  lightning_module.py # PyTorch Lightningモジュール

model/
  transformer.py     # Transformerモデル

utils/
  training_pipeline.py  # 訓練パイプライン
  inference_pipeline.py # 推論パイプライン
  decoding_strategy.py  # デコード戦略
  collate.py # collate関数

コード全体: GitHub

1. データセット準備とトークナイザー訓練

データセット

計算資源に限りがあったため、小規模な英日対訳データセットを使用して学習を行いました。

トークナイザー訓練

def prepare_transformer_tokenizers(
    dataset_name,
    vocab_size,
    src_lang,
    tgt_lang,
    src_column,
    tgt_column,
):
    from datasets import load_dataset
    from tokenizer.bpe import BPE

    dataset_dir = dataset_name.replace("/", "_")
    tokenizer_dir = f"data/tokenizers/{dataset_dir}"

    src_tokenizer_path = f"{tokenizer_dir}/{src_lang}_bpe.pkl"
    tgt_tokenizer_path = f"{tokenizer_dir}/{tgt_lang}_bpe.pkl"

    dataset = load_dataset(dataset_name, split="train")

    # 英語トークナイザーを訓練
    src_texts = [ex[src_column] for ex in dataset]
    src_tokenizer = BPE(pattern=GPT2_PATTERN if src_lang == "en" else None)
    src_tokenizer.train(src_texts, vocab_size=vocab_size)

    # 日本語トークナイザーを訓練
    tgt_texts = [ex[tgt_column] for ex in dataset]
    tgt_tokenizer = BPE(pattern=None)
    tgt_tokenizer.train(tgt_texts, vocab_size=vocab_size)

    # 保存
    os.makedirs(tokenizer_dir, exist_ok=True)
    src_tokenizer.save(src_tokenizer_path)
    tgt_tokenizer.save(tgt_tokenizer_path)

    return tokenizer_dir

2. 訓練パイプラインの構築

DataModule

データセットの事前処理を行なうモジュールです。

class TransformerDataModule(L.LightningDataModule):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.tokenized_datasets = None
        self.src_tokenizer = None
        self.tgt_tokenizer = None

    def setup(self, stage: Optional[str] = None):
        # トークナイザーをロード
        dataset_dir = self.config.data.dataset_name.replace("/", "_")
        tokenizer_dir = f"{self.config.tokenizer_dir}/{dataset_dir}"

        self.src_tokenizer = BPE.load(f"{tokenizer_dir}/{self.config.data.src_lang}_bpe.pkl")
        self.tgt_tokenizer = BPE.load(f"{tokenizer_dir}/{self.config.data.tgt_lang}_bpe.pkl")

        # 語彙サイズをconfigに設定
        self.config.model.src_vocab_size = self.src_tokenizer.get_vocab_size()
        self.config.model.tgt_vocab_size = self.tgt_tokenizer.get_vocab_size()

        # データセットをロードして分割
        dataset = load_dataset(self.config.data.dataset_name, split="train")
        train_rest = dataset.train_test_split(test_size=0.15, seed=42)
        val_test = train_rest["test"].train_test_split(test_size=0.33, seed=42)
        train_val = DatasetDict({
            "train": train_rest["train"],
            "val": val_test["train"]
        })

        # トークン化
        def preprocess_batch(batch):
            src_ids = []
            for text in batch[self.config.data.src_column]:
                ids = self.src_tokenizer.encode(text, add_special_tokens=True)
                if len(ids) > self.config.data.max_length:
                    ids = ids[:self.config.data.max_length - 1] + [
                        self.src_tokenizer.special_tokens["<EOS>"]
                    ]
                src_ids.append(ids)

            tgt_ids = []
            for text in batch[self.config.data.tgt_column]:
                ids = self.tgt_tokenizer.encode(text, add_special_tokens=True)
                if len(ids) > self.config.data.max_length:
                    ids = ids[:self.config.data.max_length - 1] + [
                        self.tgt_tokenizer.special_tokens["<EOS>"]
                    ]
                tgt_ids.append(ids)

            return {"src": src_ids, "tgt": tgt_ids}

        self.tokenized_datasets = train_val.map(
            preprocess_batch,
            batched=True,
            batch_size=1000,
            remove_columns=train_val["train"].column_names,
            desc="Tokenizing dataset",
        )

        self.tokenized_datasets.set_format(type="torch", columns=["src", "tgt"])

    def train_dataloader(self):
        return DataLoader(
            self.tokenized_datasets["train"],
            batch_size=self.config.data.batch_size,
            shuffle=True,
            collate_fn=collate,
            num_workers=self.config.data.num_workers,
        )

    def val_dataloader(self):
        return DataLoader(
            self.tokenized_datasets["val"],
            batch_size=self.config.data.batch_size,
            shuffle=False,
            collate_fn=collate,
            num_workers=self.config.data.num_workers,
        )

Collate関数

可変長の文字列をバッチ化する際、短い文字列に <PAD> を追加して長さを揃えます。

from torch.nn.utils.rnn import pad_sequence

def collate(batch, pad_id=0):
    src_tensors = [item["src"] for item in batch]
    tgt_tensors = [item["tgt"] for item in batch]

    # パディング
    src_padded = pad_sequence(src_tensors, batch_first=True, padding_value=pad_id)
    tgt_padded = pad_sequence(tgt_tensors, batch_first=True, padding_value=pad_id)

    # マスクを生成
    src_mask = src_padded != pad_id
    tgt_mask = tgt_padded != pad_id

    return {
        "src": src_padded,
        "tgt": tgt_padded,
        "src_mask": src_mask,
        "tgt_mask": tgt_mask,
    }

collate関数の処理の例:

入力:
  src: [[1, 2, 3], [4, 5]]
  tgt: [[6, 7], [8, 9, 10]]

出力:
  src: [[1, 2, 3], [4, 5, 0]]  # 0: <PAD>
  tgt: [[6, 7, 0], [8, 9, 10]]
  src_mask: [[True, True, True], [True, True, False]]
  tgt_mask: [[True, True, False], [True, True, True]]

LightningModule

訓練ループをPyTorch Lightningで抽象化したモジュールです。

class TransformerLightningModule(L.LightningModule):
    def __init__(self, config):
        super().__init__()
        self.save_hyperparameters(config.to_dict())
        self.config = config
        self.model = Transformer(config.model)
        self.criterion = nn.CrossEntropyLoss(
            ignore_index=config.data.pad_idx,
            label_smoothing=config.training.label_smoothing,
        )

    def _shared_step(self, batch, batch_idx):
        src = batch["src"]
        tgt = batch["tgt"]
        src_mask = batch["src_mask"]
        tgt_padding_mask = batch["tgt_mask"]

        # Encoder用のマスク(padding mask)
        src_mask_expanded = src_mask.unsqueeze(1) & src_mask.unsqueeze(2)

        # 入力と出力を1トークンずらす
        tgt_input = tgt[:, :-1]   # <BOS> I like sushi
        tgt_output = tgt[:, 1:]   # I like sushi <EOS>

        # Decoder用のマスク(causal mask + padding mask)
        tgt_len = tgt_input.size(1)
        causal_mask = torch.tril(torch.ones(tgt_len, tgt_len, device=tgt.device)).bool()
        tgt_padding = tgt_padding_mask[:, :-1]
        tgt_input_mask = (
            causal_mask.unsqueeze(0)
            & tgt_padding.unsqueeze(1)
            & tgt_padding.unsqueeze(2)
        )

        # Forward
        output = self.model(
            src,
            tgt_input,
            encoder_src_mask=src_mask_expanded,
            decoder_src_mask=src_mask,
            tgt_mask=tgt_input_mask,
        )

        # Loss計算
        output = output.reshape(-1, output.size(-1))
        tgt_output = tgt_output.reshape(-1)
        loss = self.criterion(output, tgt_output)

        return loss

    def training_step(self, batch, batch_idx):
        loss = self._shared_step(batch, batch_idx)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._shared_step(batch, batch_idx)
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
            self.parameters(),
            lr=self.config.optimizer.initial_lr,
            betas=(self.config.optimizer.adam_beta1, self.config.optimizer.adam_beta2),
            eps=self.config.optimizer.adam_epsilon,
        )

        # 学習率スケジューラ
        def lr_lambda(step):
            step = max(step, 1)
            model_dim = self.config.model.model_dim
            warmup_steps = self.config.optimizer.warmup_steps
            return (model_dim ** -0.5) * min(step ** -0.5, step * warmup_steps ** -1.5)

        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
                "frequency": 1,
            },
        }

Optimizer と Scheduler

Attention Is All You Needの論文で提案されている学習率スケジューラを実装しています。

$$\text{lr} = d_{\text{model}}^{-0.5} \cdot \min(\text{step}^{-0.5}, \text{step} \cdot \text{warmup_steps}^{-1.5})$$

def lr_lambda(step):
    step = max(step, 1)
    model_dim = config.model.model_dim
    warmup_steps = config.optimizer.warmup_steps
    return (model_dim ** -0.5) * min(step ** -0.5, step * warmup_steps ** -1.5)

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

3. モデルの訓練

設定

Configクラスでハイパーパラメータを一元管理しています。

Config構造

@dataclass
class Config:
    data: DataConfig           # データセットの設定
    optimizer: OptimizerConfig # Optimizerの設定
    model: TransformerModelConfig # Transformerモデルの構造設定
    training: TrainingConfig   # 訓練設定
    inference: InferenceConfig # 推論設定
    modal: ModalConfig         # Modal実行環境の設定
    wandb: WandbConfig         # Weights & Biases の設定

    run_name: str = None
    checkpoint_dir: str = "/vol/runs"
    tokenizer_dir: str = "/vol/tokenizers"

各設定の詳細

DataConfig - データセットの設定

@dataclass
class DataConfig:
    dataset_name: str       # データセット名
    batch_size: int         # バッチサイズ
    max_length: int         # 最大系列長
    num_workers: int
    pad_idx: int
    vocab_size: int = 8000  # 語彙サイズ
    src_lang: str = "en"    # 翻訳元言語
    tgt_lang: str = "ja"    # 翻訳先言語
    src_column: str = "en_sentence"  # データセットの入力カラム名
    tgt_column: str = "ja_sentence"  # データセットの出力カラム名

OptimizerConfig - Optimizerの設定

@dataclass
class OptimizerConfig:
    warmup_steps: int
    adam_beta1: float
    adam_beta2: float
    adam_epsilon: float
    initial_lr: float = 1.0  # 学習率の初期値
    optimizer_type: str = "Adam"
    scheduler_type: str = "inverse_sqrt"  # 学習率スケジューラのタイプ

TransformerModelConfig - Transformerモデルの構造設定

@dataclass
class TransformerModelConfig:
    model_type: Literal["transformer"] = "transformer"
    src_vocab_size: int = None  # 入力語彙サイズ(訓練時に自動設定)
    tgt_vocab_size: int = None  # 出力語彙サイズ(訓練時に自動設定)
    model_dim: int = 256        # モデルの次元数
    encoder_layers: int = 4     # Encoderの層数
    decoder_layers: int = 4     # Decoderの層数
    num_heads: int = 8          # AttentionのHead数
    feedforward_dim: int = None # FFNの中間層サイズ
    dropout: float = 0.1        # Dropout率
    activation: str = "relu"    # 活性化関数
    max_seq_len: int = 5000
    padding_idx: int = 0

論文よりもパラメータ数を減らしています。

TrainingConfig - 訓練設定

@dataclass
class TrainingConfig:
    num_epochs: int
    label_smoothing: float
    early_stopping_patience: int
    gradient_clip_val: float = 1.0
    precision: str = "32-true"
    accumulate_grad_batches: int = 1
    val_check_interval: float = 1.0

InferenceConfig - 推論設定

@dataclass
class InferenceConfig:
    # 特殊トークン
    pad_idx: int = 0
    unk_idx: int = 1
    bos_idx: int = 2
    eos_idx: int = 3

    # Beam Search設定
    beam_size: int = 5
    length_penalty: float = 0.6
    max_output_offset: int = 10

    max_gen_len: int = 100

ModalConfig - Modal実行環境の設定

@dataclass
class ModalConfig:
    gpu_type: str = "L40S"
    timeout_hours: int = 12
    volume_name: str = "XXX"
    secret_name: str = "XXX"

訓練にはModal.comのGPUを使用しています。

WandbConfig - Weights & Biases の設定

@dataclass
class WandbConfig:
    project: str = "XXX"
    entity: str = None
    enabled: bool = True
    log_model: bool = False

学習曲線のモニタリングにWeights & Biasesを使用しています。
GPUのメトリクスもモニタリングでき、スマホからでも進捗を見ることができるのでとても便利です。学生だと無料でライセンス認証できます。

訓練実行

まず、CPUインスタンス上でデータセットを事前処理してキャッシュに保存します。
料金が高いGPUインスタンスの使用時間を少しでも減らすための工夫です。

@app.function(
    image=image,
    volumes={"/vol": volume},
    timeout=1800,
)
def prepare_dataset(config: Config):
    dataset_dir = config.data.dataset_name.replace("/", "_")
    cache_path = f"/vol/processed/{dataset_dir}"

    if os.path.exists(cache_path):
        print(f"Data already prepared at {cache_path}")
        return

    data_module = get_data_module(config)
    data_module.prepare_data()
    data_module.setup(stage="fit")
    volume.commit()

次に、訓練を実行します:

@app.function(
    image=image,
    gpu="L40S",
    volumes={"/vol": volume},
    timeout=3600 * 12,
    secrets=[modal.Secret.from_name("wandb-secret")],
)
def train(config: Config):
    torch.set_float32_matmul_precision("high")

    # キャッシュからデータを読み込み
    data_module = get_data_module(config)
    data_module.load_from_cache()

    pl_module = TransformerLightningModule(config)

    # Callbacks設定
    callbacks = [
        EarlyStopping(
            monitor="val_loss",
            patience=config.training.early_stopping_patience,
            mode="min",
        ),
        ModelCheckpoint(
            dirpath=f"{config.checkpoint_dir}/{config.run_name}",
            filename="best_model",
            monitor="val_loss",
            mode="min",
            save_top_k=1,
        ),
    ]

    # Logger設定
    logger = WandbLogger(
        project=config.wandb.project,
        name=config.run_name,
    )

    # Trainer設定
    trainer = L.Trainer(
        max_epochs=config.training.num_epochs,
        callbacks=callbacks,
        logger=logger,
        gradient_clip_val=config.training.gradient_clip_val,
        precision=config.training.precision,
    )

    # 訓練開始
    trainer.fit(pl_module, datamodule=data_module)

    volume.commit()

以下のような学習曲線を描きました。検証データに対する損失が5.0あたりまで下がった後再び上昇し、過学習を起こしてしまっているのがわかります。

スクリーンショット 2025-12-21 22.40.47.png

4. 推論

デコード戦略

翻訳時に出力文を生成するロジックを書いていきます。以下に挙げる記事を参考に実装しています。

Greedy Decoding

各ステップで最も確率が高いトークンを選択します。

class GreedyDecoding(DecodingStrategy):
    def decode(self, model, src_tokens, src_mask, max_len):
        device = src_tokens.device
        batch_size = src_tokens.size(0)
        context = model.prepare_context(src_tokens, src_mask)
        bos_idx = self.config.bos_idx
        eos_idx = self.config.eos_idx
        results = []

        for batch_idx in range(batch_size):
            output_tokens = [bos_idx]
            for _ in range(max_len):
                tgt_input = torch.tensor([output_tokens], device=device)
                context_subset = {
                    k: v[batch_idx:batch_idx + 1] for k, v in context.items()
                }
                logits = model.generate_next_token(tgt_input, context_subset)
                # 確率が最大のトークン
                next_token = logits.argmax(dim=-1).item()
                output_tokens.append(next_token)
                if next_token == eos_idx:
                    break
            results.append(torch.tensor(output_tokens, device=device))
        return results

Beam Search

複数の候補(ビーム)を保持しながら探索します。

class BeamSearch(DecodingStrategy):
    def decode(self, model, src_tokens, src_mask, max_len):
        device = src_tokens.device
        batch_size = src_tokens.size(0)
        beam_size = self.config.beam_size
        context = model.prepare_context(src_tokens, src_mask)
        bos_idx = self.config.bos_idx
        eos_idx = self.config.eos_idx

        # 各バッチに対してビームを初期化
        beams = [
            [(torch.tensor([bos_idx], device=device), 0.0)]
            for _ in range(batch_size)
        ]

        for _ in range(max_len):
            all_candidates = []
            for batch_idx in range(batch_size):
                candidates = []
                for seq, score in beams[batch_idx]:
                    if seq[-1].item() == eos_idx:
                        candidates.append((seq, score))
                        continue

                    tgt_input = seq.unsqueeze(0)
                    context_subset = {
                        k: v[batch_idx:batch_idx + 1] for k, v in context.items()
                    }
                    logits = model.generate_next_token(tgt_input, context_subset)
                    log_probs = F.log_softmax(logits, dim=-1)
                    top_log_probs, top_indices = log_probs.topk(beam_size)

                    for k in range(beam_size):
                        new_seq = torch.cat([seq, top_indices[0, k].unsqueeze(0)])
                        new_score = score + top_log_probs[0, k].item()
                        candidates.append((new_seq, new_score))

                # 長さペナルティを適用してソート
                candidates = sorted(
                    candidates,
                    key=lambda x: x[1] / ((5 + len(x[0])) / 6) ** self.config.length_penalty,
                    reverse=True,
                )[:beam_size]
                all_candidates.append(candidates)

            beams = all_candidates

            # 全ビームが終了したらループを抜ける
            if all(all(seq[-1].item() == eos_idx for seq, _ in beam) for beam in beams):
                break

        # 各バッチで最も高評価のシーケンスを返す
        results = []
        for beam in beams:
            best_seq, _ = beam[0]
            results.append(best_seq)
        return results

推論パイプライン

def translate_sentence(model, sentence, src_tokenizer, tgt_tokenizer, config, strategy="beam"):
    results = translate_batch(
        model, [sentence], src_tokenizer, tgt_tokenizer, config, strategy
    )
    return results[0]

def translate_batch(model, sentences, src_tokenizer, tgt_tokenizer, config, strategy="beam"):
    model.eval()
    device = next(model.parameters()).device

    # 入力をトークン化
    src_ids = []
    for sentence in sentences:
        ids = src_tokenizer.encode(sentence, add_special_tokens=True)
        if len(ids) > config.data.max_length:
            ids = ids[:config.data.max_length - 1] + [
                src_tokenizer.special_tokens["<EOS>"]
            ]
        src_ids.append(ids)

    max_src_len = max(len(ids) for ids in src_ids)
    src_tokens = torch.zeros(len(sentences), max_src_len, dtype=torch.long)
    src_mask = torch.zeros(len(sentences), max_src_len, dtype=torch.bool)

    for i, ids in enumerate(src_ids):
        src_tokens[i, :len(ids)] = torch.tensor(ids)
        src_mask[i, :len(ids)] = True

    src_tokens = src_tokens.to(device)
    src_mask = src_mask.to(device)

    max_output_len = max_src_len + config.inference.max_output_offset

    # デコード戦略を選択
    if strategy == "beam":
        decoder = BeamSearch(config.inference)
    elif strategy == "greedy":
        decoder = GreedyDecoding(config.inference)
    else:
        raise ValueError(f"Unknown strategy: {strategy}")

    # 推論実行
    with torch.no_grad():
        output_seqs = decoder.decode(model, src_tokens, src_mask, max_output_len)

    # トークンIDをテキストに変換
    translations = []
    for seq in output_seqs:
        tokens = seq.tolist()
        if config.inference.bos_idx in tokens:
            tokens = tokens[1:]
        if config.inference.eos_idx in tokens:
            eos_pos = tokens.index(config.inference.eos_idx)
            tokens = tokens[:eos_pos]

        translation = tgt_tokenizer.decode(tokens, skip_special_tokens=True)
        translations.append(translation)

    return translations

翻訳結果

以下のCSVにテストデータ(訓練にも検証にも使用していない初見のデータ)に対して翻訳を行った結果をまとめています。デコード戦略にはBeam Searchを使用しています。refが正解ラベル、hypが翻訳結果です。
短文は割と頑張っていますが、少し長い文になると支離滅裂な翻訳になってしまいました。そもそもデータセットの規模が2万件と非常に小さく、モデルのパラメータ数を削ったりしているので妥当な結果ではありますが、Transformerの雰囲気を掴むにはいいタスクだったと思います。

まとめ

本記事では、実装したTransformerモデルを使って英日翻訳タスクを解きました。
Part4ではGPT-2の実装を解説します。

ゼロから作るLLM - 目次

2
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?