0
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

RNNがわからない理由を、PyTorchのコードで完全に言語化してみた

0
Last updated at Posted at 2026-01-06

はじめに

最近、LLMの仕組みを理解する必要に迫られ、改めてディープラーニングを学び直した時に感じたことがあります。

「CNN(畳み込みニューラルネットワーク)は理解できたが、RNN(リカレントニューラルネットワーク)がスッと理解できない」

結論として、RNNが理解しにくい最大の理由は、
「層が増えているように見えるが、実際は同じ重みをループで使い回している」
という点をコードレベルで認識できていないことです。

対象読者

  • 理論よりも先にコードを動かしてみたい方
  • ディープラーニングを学び直したい方
  • 時系列データや言語モデルに興味がある方
  • RNNの図は見たことがあるが、具体的な処理のイメージが湧いていない方

まずはコードを動かしてみる

今回は英語→日本語への翻訳をタスクと設定して、実際にコードを動かしてみます。
ネットワークはクラスとして実装し、Encoder-Decoderモデルを採用します。今回はRNN(Simple RNN)ではなく、より実用的なLSTM (Long Short-Term Memory) で実装しています。

データの準備

テストデータとして、以下のような形式の trainGen.csv を用意します。

trainGen.csv
//英文,日本語(わかち書き済み)
this is a good pen,これ は よ い ペン だ

英語と日本語のペアデータは、以下のデータセット(5万行の英文・日本語訳)を利用させていただきました。

実装コード


import csv
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

############################################
# 1. デバイス設定
############################################

if torch.backends.mps.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

torch.set_float32_matmul_precision("high")
torch.manual_seed(42)

############################################
# 2. 設定
############################################

BATCH_SIZE = 64
EMB_SIZE = 256
HIDDEN_SIZE = 512
EPOCHS = 20
LR = 1e-3
MAX_LEN = 30

MODEL_PATH = "seq2seq_lstm.pt"

############################################
# 3. 語彙
############################################

PAD, SOS, EOS, UNK = "<pad>", "<sos>", "<eos>", "<unk>"

class Vocab:
    def __init__(self):
        self.word2idx = {PAD:0, SOS:1, EOS:2, UNK:3}
        self.idx2word = {0:PAD, 1:SOS, 2:EOS, 3:UNK}

    def add_sentence(self, tokens):
        for w in tokens:
            if w not in self.word2idx:
                idx = len(self.word2idx)
                self.word2idx[w] = idx
                self.idx2word[idx] = w

    def encode(self, tokens):
        return [self.word2idx.get(w, self.word2idx[UNK]) for w in tokens]

    def __len__(self):
        return len(self.word2idx)

############################################
# 4. Dataset
############################################

class TranslationDataset(Dataset):
    def __init__(self, csv_path):
        self.pairs = []
        self.src_vocab = Vocab()
        self.tgt_vocab = Vocab()

        with open(csv_path, encoding="utf-8") as f:
            reader = csv.reader(f)
            for row in reader:
                if len(row) < 2:
                    continue

                src = row[0].lower().split()
                tgt = row[1].split()

                if len(src) > MAX_LEN or len(tgt) > MAX_LEN:
                    continue

                self.src_vocab.add_sentence(src)
                self.tgt_vocab.add_sentence(tgt)
                self.pairs.append((src, tgt))

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        src, tgt = self.pairs[idx]
        src_ids = self.src_vocab.encode(src) + [self.src_vocab.word2idx[EOS]]
        tgt_ids = [self.tgt_vocab.word2idx[SOS]] + \
                  self.tgt_vocab.encode(tgt) + \
                  [self.tgt_vocab.word2idx[EOS]]
        return torch.tensor(src_ids), torch.tensor(tgt_ids)

def collate_fn(batch):
    srcs, tgts = zip(*batch)

    src_len = max(len(s) for s in srcs)
    tgt_len = max(len(t) for t in tgts)

    src_pad = torch.zeros(len(batch), src_len, dtype=torch.long)
    tgt_pad = torch.zeros(len(batch), tgt_len, dtype=torch.long)

    for i, (s, t) in enumerate(zip(srcs, tgts)):
        src_pad[i, :len(s)] = s
        tgt_pad[i, :len(t)] = t

    return src_pad.to(DEVICE), tgt_pad.to(DEVICE)

############################################
# 5. Encoder / Decoder
############################################

class Encoder(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, EMB_SIZE, padding_idx=0)
        self.lstm = nn.LSTM(EMB_SIZE, HIDDEN_SIZE, batch_first=True)
        # self.lstm = ManualFunctionLSTM(EMB_SIZE, HIDDEN_SIZE)

    def forward(self, x):
        emb = self.emb(x)
        _, (h, c) = self.lstm(emb)
        return h, c

class Decoder(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, EMB_SIZE, padding_idx=0)
        self.lstm = nn.LSTM(EMB_SIZE, HIDDEN_SIZE, batch_first=True)
        # self.lstm = ManualFunctionLSTM(EMB_SIZE, HIDDEN_SIZE)
        self.fc = nn.Linear(HIDDEN_SIZE, vocab_size)

    def forward(self, x, state):
        emb = self.emb(x)
        out, state = self.lstm(emb, state)
        logits = self.fc(out)
        return logits, state

############################################
# 6. Seq2Seq
############################################

class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, src, tgt):
        state = self.encoder(src)
        logits, _ = self.decoder(tgt[:, :-1], state)
        return logits

############################################
# 7. 学習 or ロード
############################################

def train():
    dataset = TranslationDataset("trainGen.csv")

    model = Seq2Seq(
        Encoder(len(dataset.src_vocab)),
        Decoder(len(dataset.tgt_vocab))
    ).to(DEVICE)

    if os.path.exists(MODEL_PATH):
        print("Loading trained model...")
        checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
        model.load_state_dict(checkpoint["model"])
        dataset.src_vocab.word2idx = checkpoint["src_word2idx"]
        dataset.src_vocab.idx2word = {
            v: k for k, v in checkpoint["src_word2idx"].items()
        }

        dataset.tgt_vocab.word2idx = checkpoint["tgt_word2idx"]
        dataset.tgt_vocab.idx2word = {
            v: k for k, v in checkpoint["tgt_word2idx"].items()
        }
        return model, dataset

    loader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        collate_fn=collate_fn
    )

    optimizer = optim.Adam(model.parameters(), lr=LR)
    criterion = nn.CrossEntropyLoss(ignore_index=0)

    for epoch in range(EPOCHS):
        model.train()
        total_loss = 0

        for src, tgt in loader:
            optimizer.zero_grad()
            logits = model(src, tgt)

            loss = criterion(
                logits.reshape(-1, logits.size(-1)),
                tgt[:, 1:].reshape(-1)
            )
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            total_loss += loss.item()

        print(f"Epoch {epoch+1}, Loss {total_loss/len(loader):.4f}")

    torch.save({
        "model": model.state_dict(),
        "src_word2idx": dataset.src_vocab.word2idx,
        "tgt_word2idx": dataset.tgt_vocab.word2idx,
    }, MODEL_PATH)

    return model, dataset

############################################
# 8. 推論
############################################

@torch.no_grad()
def translate(model, dataset, sentence):
    model.eval()

    src_ids = dataset.src_vocab.encode(sentence.lower().split()) + \
              [dataset.src_vocab.word2idx[EOS]]
    src = torch.tensor(src_ids).unsqueeze(0).to(DEVICE)

    state = model.encoder(src)

    cur = torch.tensor([[dataset.tgt_vocab.word2idx[SOS]]]).to(DEVICE)
    result = []

    for _ in range(MAX_LEN):
        logits, state = model.decoder(cur, state)
        next_id = logits.argmax(-1).item()
        if next_id == dataset.tgt_vocab.word2idx[EOS]:
            break
        result.append(dataset.tgt_vocab.idx2word[next_id])
        cur = torch.tensor([[next_id]]).to(DEVICE)

    return " ".join(result)

############################################
# 9. 実行
############################################

if __name__ == "__main__":
    model, dataset = train()
    print(translate(model, dataset, "i want to be a baseball player"))
    print(translate(model, dataset, "he has a good pen"))
    print(translate(model, dataset, "i like music"))



出力

上記のファイルを実行すると以下のように出力されました。
何パターンかパラメータを試しましたが、上記の設定でもそれなりに精度が出ていました。
翻訳も違和感がないかと思います。

Epoch 1, Loss 3.3633
Epoch 2, Loss 2.3256
Epoch 3, Loss 1.7701
Epoch 4, Loss 1.3823
Epoch 5, Loss 1.0870
Epoch 6, Loss 0.8503
Epoch 7, Loss 0.6629
Epoch 8, Loss 0.5163
Epoch 9, Loss 0.4041
Epoch 10, Loss 0.3173
Epoch 11, Loss 0.2548
Epoch 12, Loss 0.2100
Epoch 13, Loss 0.1785
Epoch 14, Loss 0.1553
Epoch 15, Loss 0.1410
Epoch 16, Loss 0.1305
Epoch 17, Loss 0.1206
Epoch 18, Loss 0.1151
Epoch 19, Loss 0.1099
Epoch 20, Loss 0.1064
私 は 野球 選手 に な り た い
彼 は ペン を 持 っ て い る
私 は 音楽 が 好き だ

RNNの何が理解を阻むのか

体感ですが、ディープラーニングを学ぼうとして最初に手を出すのは「画像分類」であることが多い気がします。データの豊富さやわかりやすさから、オライリーの『ゼロから作るDeep Learning』などを通じてCNNから入るパターンです。

「CNNはすっと頭に入ったが、RNNは頭に入らない」「図の意味はなんとなくわかったが、何をしているのかイメージできない」という方は結構いるのではないでしょうか。
CNNから順に学んだ場合、一体何が理解を阻むのかを分析してみました。


1. CNNで学んだことがどう活かせるのかがわからない

一番最初に感じたのは、「CNNで学んだ技術のうち、どれがRNNに活かせるのかわからない」という点です。

結論から言うと、CNNで学んだコアな技術(誤差逆伝播法、活性化関数、最適化手法など)はそのまま使えます。

今回のコードを見ると、train() 関数内部での学習ループ自体は、CNNの時とほぼ同じ構造であることがわかります。

# train()内の一部抜粋

    optimizer = optim.Adam(model.parameters(), lr=LR)
    criterion = nn.CrossEntropyLoss(ignore_index=0)

    for epoch in range(EPOCHS):
        # ... (中略) ...
        for src, tgt in loader:
            optimizer.zero_grad()
            logits = model(src, tgt)

            loss = criterion(...)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

Lossの定義などは共通の文法ですが、系列データ特有のテクニックとして 教師強制(Teacher Forcing) がある点には注意が必要です。

簡単に言うと、「学習時、Decoderへの入力として『1つ前の時刻の予測結果』を使うのではなく、『正解データ』を使う手法」 のことです。

学習初期にモデルの精度が低い状態で、間違った予測結果を入力にし続けると学習が全く進まないため、カンニングペーパー(正解)を見せながら学習させるイメージです。コード上では tgt(正解ラベル)をそのまま入力として渡している部分がそれに当たります。

また、言語モデルにおける正解は「次単語を正確に予測すること」なので、Lossの計算では「予測値 」と「正解データ 」を比較しています(コードでは tgt[:, 1:] と1つずらして比較しています)。


2. RNNとLSTMの違いがわからない

結論から言うと、Simple RNNとLSTMは実装上は別物です。
「RNN」という言葉が、時系列モデルの総称(Recurrent Neural Networks)として使われる場合と、特定の層(Simple RNN)を指す場合があるため、混同しやすいのだと思います。

pytorch.nn では明確に別のクラスとして実装されています。

今回のコードでは、勾配消失問題などに強い LSTM を利用しています。

# Encoder Class内

class Encoder(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, EMB_SIZE, padding_idx=0)
        # ここでLSTMを指定
        self.lstm = nn.LSTM(EMB_SIZE, HIDDEN_SIZE, batch_first=True)

nn.LSTM には EMB_SIZE(入力次元)と HIDDEN_SIZE(隠れ層の次元)の2つの引数を渡しています。HIDDEN_SIZE はモデルの表現力の高さ(記憶容量のようなもの)に直結します。

3. 図を見てみたが層や処理の流れのイメージがわかない

RNNの説明でよく見る以下のような図。これがパッと頭に入ってこない原因は、「層」の概念にあると思います。

CNNでは「層を深く積み重ねる」イメージが強いため、この図も「層が横に並んでいる」ように見えがちです。

しかし結論として、この図は複数の層があることを表しているのではなく、ループ処理で「全く同じユニット(重み)」を何度も使い回している(回帰している) ことを表しています。

これを「時間展開」と呼びます。
コード上の nn.LSTM も、内部ではループ処理を行っています。時刻 の計算も の計算も、使われる重みパラメータは同一です。これが「Recurrent(回帰する)」である所以です。


4. いざ翻訳をやってみたいが「言語」を扱う方法がわからない

画像データはピクセル値(数値)ですが、言語はそのままでは計算できません。
結論、各単語(トークン)は、演算が可能なN次元のベクトルに変換する必要があります。

これを扱うために、PyTorchでは nn.Embedding という層が用意されています。

self.emb = nn.Embedding(vocab_size, EMB_SIZE, padding_idx=0)

これは巨大なルックアップテーブルのようなものです。
「語彙数 × ベクトル次元数」の行列を持ち、単語IDを渡すと、その単語に対応するベクトルを返してくれます。最初はランダムな値ですが、学習を通して「意味の近い単語は近いベクトルになる」ように更新されていきます。


5. 図の意味はわかったが具体的なタスクを実装する方法がわからない

翻訳タスク(Seq2Seq)では、翻訳前言語を読み込むEncoderと、翻訳後言語を生成するDecoderの連携が必要です。ここで重要なのが、「RNN/LSTMが出力するデータ(隠れ状態)は何を表しているのか」という点です。

    def forward(self, src, tgt):
        # Encoderで入力文を「意味ベクトル(state)」に変換
        state = self.encoder(src)
        # その「意味」を使ってDecoderで翻訳文を生成
        logits, _ = self.decoder(tgt[:, :-1], state)
        return logits

Encoderは英文を最初から最後まで読み込み、そのすべての情報を**固定長のベクトル(state: 隠れ状態 とセル状態 )に圧縮します。
この state こそが、
「入力された文章の意味」**そのものです。

  • Encoderの役割: 文章を読み、その意味(文脈)を state として出力する。
  • Decoderの役割: 受け取った state(文脈)と、開始合図の <sos> を元に、最初の単語を予測し、次々と単語を出力。

推論時のコード(translate 関数)を見ると、この流れがよくわかります。

    # 1. Encoderで英文を「意味(state)」に変換
    state = model.encoder(src)

    # 2. 最初の入力は <sos>
    cur = torch.tensor([[dataset.tgt_vocab.word2idx[SOS]]]).to(DEVICE)
    
    # 3. ループで1単語ずつ生成
    for _ in range(MAX_LEN):
        # 前の単語と、維持している「文脈(state)」を渡す
        logits, state = model.decoder(cur, state)
        # ...(次単語の決定処理)

おわりに

今回は、CNNの知識をベースに、RNN/LSTMのつまづきポイントと翻訳タスクの実装について解説しました。
LSTMを使うことで、可変長の入出力を扱うモデルが動くようになりました。

次回は、このモデルをベースにAttentionメカニズムへの拡張に繋げていきたいと思います。

0
3
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?