LoginSignup
19
11

More than 1 year has passed since last update.

PyTorchでTransformer

Posted at

はじめに

以前の記事で、公式のチュートリアルにそってPyTorchでニューラルネットワークを作成しました。

BERTの実装を最終的な目標としていますが、BERTと同じAttentionベースのモデルであるTransformerのチュートリアルがPyTorchの公式にあったので、今回はこれにそってTransformerを作成してみます。

PyTorchによるTransformerの作成

今回は、Transformerに、途中で切れた文が与えられた時に、次にくる単語が何であるかを推測するタスクでTransformerの学習を行います。

環境作成

ここではMac OSでpipを使った場合の環境作成方法を説明します(使用したOSはMac OS 12.3.1)。
その他の場合は、こちらを参考に環境を構築してください。

(1) Homebrewでpython3をインストール

$ brew install python3

(2) pipを使ってPyTorchと今回利用するデータセットをインストール(チュートリアルではtorchtextのインストールに触れられていませんが、エラーになってしまったので追加でインストールしました)

$ pip3 install torch torchdata torchtext

なお、Google Colaboratoryなどのクラウドサービスを使えば、GPUを簡単に利用することができます。

Transformerモデルの作成

まず、下図に示すTransformerモデルを作成していきます。

transformer.png

Input Embedding

nn.Embeddingモジュールを使って、入力のOne-hotベクトルで表現されている単語を、埋め込みベクトルに変換します。

self.encoder = nn.Embedding(ntoken, d_model)
src = self.encoder(src) * math.sqrt(self.d_model)

Positional Encoding

Transformerでは、単語の位置情報を埋め込みベクトルに付加する必要があります。
そのために、下記のようなモジュールを作成します。

class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim]
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

このモジュールを使って、埋め込みベクトルに位置情報を追加します。

self.pos_encoder = PositionalEncoding(d_model, dropout)
src = self.pos_encoder(src)

Multi-Head Attention + Feed Forward

PyTorchでは、図のグレーの部分(Multi-Head Attention + Feed Forward)の層がnn.TransformerEncoderLayerとして提供されています。
また、Transformerではグレー部分の層を複数繰り返してつなげます。この繰り返しのためにnn.TransformerEncoderが用意されています。

encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, dropout)
self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
output = self.transformer_encoder(src, src_mask)

Linear

nn.Linearモジュールで、結果のベクトルを全語彙に該当するベクトルにマッピングします。
この結果のうち最も値の大きい単語が、入力された文(途中で切れている)の次にくる単語となります。

self.decoder = nn.Linear(d_model, ntoken)
output = self.decoder(output)

以下が、作成したTransformerのモデルです。
__init__メソッドでネットワーク構造を定義し、forwardメソッドで入力データに対する処理を実装します。
forwardメソッドの戻り値がネットワークの出力となります。

import math
from typing import Tuple

import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.utils.data import dataset

class TransformerModel(nn.Module):

    def __init__(self, ntoken: int, d_model: int, nhead: int, d_hid: int,
                 nlayers: int, dropout: float = 0.5):
        super().__init__()
        self.model_type = 'Transformer'
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.encoder = nn.Embedding(ntoken, d_model)
        self.d_model = d_model
        self.decoder = nn.Linear(d_model, ntoken)

        self.init_weights()

    def init_weights(self) -> None:
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, src: Tensor, src_mask: Tensor) -> Tensor:
        """
        Args:
            src: Transformerへの入力データ
            src_mask: 入力データにかけるマスク
        Returns:
            Transformerの出力
        """
        src = self.encoder(src) * math.sqrt(self.d_model)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src, src_mask)
        output = self.decoder(output)
        return output

データセットの読み込み

今回はデータセットとして、WikiText2を使います。
WikiText2は、環境構築で既にダウンロード済みです。

テキストのOne-hotベクトル化

data_processメソッドで、WikiText2内の各単語を1次元のOne-hotベクトルに変換します。

from torchtext.datasets import WikiText2
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

train_iter = WikiText2(split='train')
tokenizer = get_tokenizer('basic_english')
vocab = build_vocab_from_iterator(map(tokenizer, train_iter), specials=['<unk>'])
vocab.set_default_index(vocab['<unk>'])

def data_process(raw_text_iter: dataset.IterableDataset) -> Tensor:
    data = [torch.tensor(vocab(tokenizer(item)), dtype=torch.long) for item in raw_text_iter]
    return torch.cat(tuple(filter(lambda t: t.numel() > 0, data)))

# train_iterはvocabの構築で消費されるため、ここでもう一度作成する
train_iter, val_iter, test_iter = WikiText2()
train_data = data_process(train_iter)
val_data = data_process(val_iter)
test_data = data_process(test_iter)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

ミニバッチ作成

並列処理を可能にするために、batchifyメソッドを使って、入力データをいくつかのグループ(ミニバッチ)に分割します。グループの数は引数のbszで指定します。

例えば、dataとして下記のような26文字のアルファベットからなるベクトル、

\left[
\begin{array}{ccccccc}
A & B & C & \ldots & X & Y & Z
\end{array}
\right]

bsz(バッチサイズ)を4とした場合、出力は下記のようになります。

\begin{eqnarray}
\left[
\begin{array}{cccc}
\left[
\begin{array}{c}
A\\B\\C\\D\\E\\F
\end{array}
\right]&
\left[
\begin{array}{c}
G\\H\\I\\J\\K\\L
\end{array}
\right]&
\left[
\begin{array}{c}
M\\N\\O\\P\\Q\\R
\end{array}
\right]&
\left[
\begin{array}{c}
S\\T\\U\\V\\W\\X
\end{array}
\right]
\end{array}
\right]
\end{eqnarray}
def batchify(data: Tensor, bsz: int) -> Tensor:
    """
    Args:
        data: 入力データ
        bsz: バッチサイズ

    Returns:
        ミニバッチ
    """
    seq_len = data.size(0) // bsz
    data = data[:seq_len * bsz]
    data = data.view(bsz, seq_len).t().contiguous()
    return data.to(device)

batch_size = 20
eval_batch_size = 10
train_data = batchify(train_data, batch_size)  # shape [seq_len, batch_size]
val_data = batchify(val_data, eval_batch_size)
test_data = batchify(test_data, eval_batch_size)

学習用データ、正解データの作成

get_batchメソッドを使って、学習用データ(または検証用データ)と正解データを作成します。
正解データは、学習用データの単語列を、1単語前にずらした単語列となります。

先程のアルファベットのミニバッチを入力とし、iを0、bpttを2とした場合、get_batchの出力は以下のようになります(左が学習用データ、右が訓練用データ)。

\begin{eqnarray}
\left(
\begin{array}{ccc}
\left[
\begin{array}{c}
\left[
\begin{array}{cccc}
A & G & M & S
\end{array}
\right]\\
\left[
\begin{array}{cccc}
B & H & N & T
\end{array}
\right]
\end{array}
\right]
 & , &
\left[
\begin{array}{c}
\left[
\begin{array}{cccc}
B & H & N & T
\end{array}
\right]\\
\left[
\begin{array}{cccc}
C & I & O & U
\end{array}
\right]
\end{array}
\right]
\end{array}
\right)
\end{eqnarray}
bptt = 35
def get_batch(source: Tensor, i: int) -> Tuple[Tensor, Tensor]:
    """
    Args:
        source: ミニバッチ
        i: int

    Returns:
        学習用データ、正解データのtuple (data, target)
    """
    seq_len = min(bptt, len(source) - 1 - i)
    data = source[i:i+seq_len]
    target = source[i+1:i+1+seq_len].reshape(-1)
    return data, target

マスキング

generate_square_subsequent_maskを使って、Transformerが学習を行う際に、現在注目している単語に後続する単語を参照できなくします。

def generate_square_subsequent_mask(sz: int) -> Tensor:
    return torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)

学習と検証

ここから、Transformerを使った学習と検証を説明します。大まかな流れは、以前作成したニューラルネットワークの時と同じです。

損失関数

Transformerの学習では、出力と正解との誤差(損失関数)を計算し、損失関数の値が小さくなるようにしていきます。
今回は、多クラス分類の学習であるため交差エントロピー誤差(cross entropy error)を損失関数として使います。PyTorchではnn.CrossEntropyLossとして提供されています。

criterion = nn.CrossEntropyLoss()

最適化

最適化とは、先ほど説明した損失関数の値が小さくなるように、Transformerのパラメーター(重み、バイアス)を調整することです。パラメーターの調整量のことを勾配(gradient)と呼びます。
今回は最適化アルゴリズムの1つである確率的勾配降下法(stochastic gradient descent, SDG)を使用します。PyTorchではtorch.optim.SGDとして提供されています。

torch.optim.SGDには、モデルのパラメーターと学習係数を指定します。学習係数によりパラメーターの更新量を調整することができます。

lr = 5.0  # 学習係数
optimizer = torch.optim.SGD(model.parameters(), lr=lr)

学習

ミニバッチ単位で以下の手順を実行し、学習を行います。

(1) Transformerに学習用データを入力し、出力を得る。

# data: Transformerへの入力、src_mask: dataにかけるマスク、model: Transformerモデル、output: Transformerからの出力
output = model(data, src_mask)

(2) 出力と正解から損失関数を計算する。

# output: Transformerからの出力、targets: 正解、criterion: 損失関数
loss = criterion(output.view(-1, ntokens), targets)

(3) 勾配の値をリセットする(0にする)。

optimizer.zero_grad()

(4) 損失関数から誤差逆伝播法(back propagation)により、Transformerの全パラメーターの勾配を計算する。勾配計算はPyTorch組み込みの微分エンジンtorch.autogradにより行われています。詳細を知りたい方はこちらをご覧ください。

loss.backward()

(5) 計算した勾配を用いて、全パラメーターの値を更新する

optimizer.step()

以下は、上記の手順をまとめたメソッドです。

import copy
import time

ntokens = len(vocab)  # 語彙数
emsize = 200  # 埋め込みベクトルの次元
d_hid = 200  # nn.TransformerEncoderのフィードフォワードネットワークの次元
nlayers = 2  # nn.TransformerEncoder内のnn.TransformerEncoderLayerの数
nhead = 2  # nn.MultiheadAttention内のヘッドの数
dropout = 0.2  # dropoutの割合
model = TransformerModel(ntokens, emsize, nhead, d_hid, nlayers, dropout).to(device)

criterion = nn.CrossEntropyLoss()
lr = 5.0  # learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)

def train(model: nn.Module) -> None:
    model.train()  # 学習モードに切り替え
    total_loss = 0.
    log_interval = 200
    start_time = time.time()
    src_mask = generate_square_subsequent_mask(bptt).to(device)

    num_batches = len(train_data) // bptt
    for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):
        data, targets = get_batch(train_data, i)
        batch_size = data.size(0)
        if batch_size != bptt:  # 最後のバッチの時のみ
            src_mask = src_mask[:batch_size, :batch_size]
        output = model(data, src_mask)
        loss = criterion(output.view(-1, ntokens), targets)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        total_loss += loss.item()
        if batch % log_interval == 0 and batch > 0:
            lr = scheduler.get_last_lr()[0]
            ms_per_batch = (time.time() - start_time) * 1000 / log_interval
            cur_loss = total_loss / log_interval
            ppl = math.exp(cur_loss)
            print(f'| epoch {epoch:3d} | {batch:5d}/{num_batches:5d} batches | '
                  f'lr {lr:02.2f} | ms/batch {ms_per_batch:5.2f} | '
                  f'loss {cur_loss:5.2f} | ppl {ppl:8.2f}')
            total_loss = 0
            start_time = time.time()

検証

検証では、検証データをTransformerに入力し、得られた出力と正解との誤差を計算します。
検証では学習が不要なため、torch.no_grad()によって勾配計算に必要な処理を無効にします(処理性能向上のため)。

def evaluate(model: nn.Module, eval_data: Tensor) -> float:
    model.eval()  # 検証モードに切り替え
    total_loss = 0.
    src_mask = generate_square_subsequent_mask(bptt).to(device)
    with torch.no_grad():
        for i in range(0, eval_data.size(0) - 1, bptt):
            data, targets = get_batch(eval_data, i)
            batch_size = data.size(0)
            if batch_size != bptt:
                src_mask = src_mask[:batch_size, :batch_size]
            output = model(data, src_mask)
            output_flat = output.view(-1, ntokens)
            total_loss += batch_size * criterion(output_flat, targets).item()
    return total_loss / (len(eval_data) - 1)

プログラムの実行

最後に、下記のプログラムで3エポックの学習(train)+検証(evaluate)を繰り返します。

best_val_loss = float('inf')
#エポック数
epochs = 3
best_model = None

for epoch in range(1, epochs + 1):
    epoch_start_time = time.time()
    train(model)
    val_loss = evaluate(model, val_data)
    val_ppl = math.exp(val_loss)
    elapsed = time.time() - epoch_start_time
    print('-' * 89)
    print(f'| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | '
          f'valid loss {val_loss:5.2f} | valid ppl {val_ppl:8.2f}')
    print('-' * 89)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model = copy.deepcopy(model)

    scheduler.step()

その際の出力結果がこちらです。学習が進むにつれ正解率(Accuracy)が上昇し、誤差(loss)が小さくなっていることが確認できます。
ちなみに、GPUは使っていないため、3エポックでもかなりの時間がかかりました。実際に試してみる方は、Google Colaboratoryでの実行をおすすめします。

| epoch   1 |   200/ 2928 batches | lr 5.00 | ms/batch 420.44 | loss  8.01 | ppl  3018.63
| epoch   1 |   400/ 2928 batches | lr 5.00 | ms/batch 434.12 | loss  6.86 | ppl   950.57
| epoch   1 |   600/ 2928 batches | lr 5.00 | ms/batch 440.21 | loss  6.42 | ppl   615.38
| epoch   1 |   800/ 2928 batches | lr 5.00 | ms/batch 436.10 | loss  6.29 | ppl   540.35
| epoch   1 |  1000/ 2928 batches | lr 5.00 | ms/batch 440.15 | loss  6.18 | ppl   480.69
| epoch   1 |  1200/ 2928 batches | lr 5.00 | ms/batch 430.71 | loss  6.15 | ppl   468.42
| epoch   1 |  1400/ 2928 batches | lr 5.00 | ms/batch 424.53 | loss  6.11 | ppl   449.43
| epoch   1 |  1600/ 2928 batches | lr 5.00 | ms/batch 426.39 | loss  6.10 | ppl   445.31
| epoch   1 |  1800/ 2928 batches | lr 5.00 | ms/batch 422.44 | loss  6.02 | ppl   411.21
| epoch   1 |  2000/ 2928 batches | lr 5.00 | ms/batch 412.33 | loss  6.01 | ppl   407.06
| epoch   1 |  2200/ 2928 batches | lr 5.00 | ms/batch 417.48 | loss  5.89 | ppl   359.91
| epoch   1 |  2400/ 2928 batches | lr 5.00 | ms/batch 430.99 | loss  5.96 | ppl   389.16
| epoch   1 |  2600/ 2928 batches | lr 5.00 | ms/batch 440.37 | loss  5.95 | ppl   383.34
| epoch   1 |  2800/ 2928 batches | lr 5.00 | ms/batch 433.94 | loss  5.88 | ppl   357.36
-----------------------------------------------------------------------------------------
| end of epoch   1 | time: 1303.18s | valid loss  5.81 | valid ppl   334.50
-----------------------------------------------------------------------------------------
| epoch   2 |   200/ 2928 batches | lr 4.75 | ms/batch 448.69 | loss  5.85 | ppl   347.13
| epoch   2 |   400/ 2928 batches | lr 4.75 | ms/batch 451.74 | loss  5.84 | ppl   344.30
| epoch   2 |   600/ 2928 batches | lr 4.75 | ms/batch 442.30 | loss  5.65 | ppl   285.24
| epoch   2 |   800/ 2928 batches | lr 4.75 | ms/batch 450.91 | loss  5.69 | ppl   296.31
| epoch   2 |  1000/ 2928 batches | lr 4.75 | ms/batch 461.74 | loss  5.64 | ppl   282.26
| epoch   2 |  1200/ 2928 batches | lr 4.75 | ms/batch 464.20 | loss  5.67 | ppl   290.53
| epoch   2 |  1400/ 2928 batches | lr 4.75 | ms/batch 433.12 | loss  5.68 | ppl   293.03
| epoch   2 |  1600/ 2928 batches | lr 4.75 | ms/batch 433.08 | loss  5.70 | ppl   299.90
| epoch   2 |  1800/ 2928 batches | lr 4.75 | ms/batch 446.36 | loss  5.64 | ppl   281.13
| epoch   2 |  2000/ 2928 batches | lr 4.75 | ms/batch 443.56 | loss  5.66 | ppl   286.80
| epoch   2 |  2200/ 2928 batches | lr 4.75 | ms/batch 445.23 | loss  5.54 | ppl   255.44
| epoch   2 |  2400/ 2928 batches | lr 4.75 | ms/batch 454.44 | loss  5.63 | ppl   279.42
| epoch   2 |  2600/ 2928 batches | lr 4.75 | ms/batch 451.50 | loss  5.63 | ppl   279.43
| epoch   2 |  2800/ 2928 batches | lr 4.75 | ms/batch 456.88 | loss  5.57 | ppl   262.01
-----------------------------------------------------------------------------------------
| end of epoch   2 | time: 1361.26s | valid loss  5.65 | valid ppl   284.21
-----------------------------------------------------------------------------------------
| epoch   3 |   200/ 2928 batches | lr 4.51 | ms/batch 443.37 | loss  5.59 | ppl   268.65
| epoch   3 |   400/ 2928 batches | lr 4.51 | ms/batch 476.49 | loss  5.61 | ppl   272.55
| epoch   3 |   600/ 2928 batches | lr 4.51 | ms/batch 466.81 | loss  5.41 | ppl   222.53
| epoch   3 |   800/ 2928 batches | lr 4.51 | ms/batch 447.77 | loss  5.47 | ppl   236.85
| epoch   3 |  1000/ 2928 batches | lr 4.51 | ms/batch 447.26 | loss  5.42 | ppl   225.61
| epoch   3 |  1200/ 2928 batches | lr 4.51 | ms/batch 444.40 | loss  5.46 | ppl   235.93
| epoch   3 |  1400/ 2928 batches | lr 4.51 | ms/batch 461.98 | loss  5.49 | ppl   241.44
| epoch   3 |  1600/ 2928 batches | lr 4.51 | ms/batch 433.08 | loss  5.52 | ppl   249.04
| epoch   3 |  1800/ 2928 batches | lr 4.51 | ms/batch 426.84 | loss  5.46 | ppl   234.91
| epoch   3 |  2000/ 2928 batches | lr 4.51 | ms/batch 426.60 | loss  5.48 | ppl   238.70
| epoch   3 |  2200/ 2928 batches | lr 4.51 | ms/batch 5334.73 | loss  5.35 | ppl   210.36
| epoch   3 |  2400/ 2928 batches | lr 4.51 | ms/batch 450.23 | loss  5.46 | ppl   234.02
| epoch   3 |  2600/ 2928 batches | lr 4.51 | ms/batch 453.25 | loss  5.47 | ppl   238.04
| epoch   3 |  2800/ 2928 batches | lr 4.51 | ms/batch 468.60 | loss  5.39 | ppl   219.83
-----------------------------------------------------------------------------------------
| end of epoch   3 | time: 2341.51s | valid loss  5.60 | valid ppl   269.74
-----------------------------------------------------------------------------------------

まとめ

前回のPyTorchを使ったニューラルネットワークの作成に続き、今回はTransformerモデルの作成をしましたが、入力データ、モデルの作成以外は、ほぼ同じ実装なことに気づきました。
PyTorchでのCNNやRNNは未経験なので同じパターンかどうかわかりませんが、このパターンは覚えておいて損はなさそうです。

次回からは、これまで学習してきた内容をもとに、PyTorchを使ってBERTの実装をしてみようと思います。

今回作成したプログラムはGitHubで公開しています。
https://github.com/age884/qiita_transformer/blob/main/transformer.py

参考文献

19
11
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
19
11