はじめに
以前の記事で、公式のチュートリアルにそってPyTorchでニューラルネットワークを作成しました。
BERTの実装を最終的な目標としていますが、BERTと同じAttentionベースのモデルであるTransformerのチュートリアルがPyTorchの公式にあったので、今回はこれにそってTransformerを作成してみます。
PyTorchによるTransformerの作成
今回は、Transformerに、途中で切れた文が与えられた時に、次にくる単語が何であるかを推測するタスクでTransformerの学習を行います。
環境作成
ここではMac OSでpipを使った場合の環境作成方法を説明します(使用したOSはMac OS 12.3.1)。
その他の場合は、こちらを参考に環境を構築してください。
(1) Homebrewでpython3をインストール
$ brew install python3
(2) pipを使ってPyTorchと今回利用するデータセットをインストール(チュートリアルではtorchtextのインストールに触れられていませんが、エラーになってしまったので追加でインストールしました)
$ pip3 install torch torchdata torchtext
なお、Google Colaboratoryなどのクラウドサービスを使えば、GPUを簡単に利用することができます。
Transformerモデルの作成
まず、下図に示すTransformerモデルを作成していきます。
Input Embedding
nn.Embedding
モジュールを使って、入力のOne-hotベクトルで表現されている単語を、埋め込みベクトルに変換します。
self.encoder = nn.Embedding(ntoken, d_model)
src = self.encoder(src) * math.sqrt(self.d_model)
Positional Encoding
Transformerでは、単語の位置情報を埋め込みベクトルに付加する必要があります。
そのために、下記のようなモジュールを作成します。
class PositionalEncoding(nn.Module):
def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
super().__init__()
self.dropout = nn.Dropout(p=dropout)
position = torch.arange(max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
pe = torch.zeros(max_len, 1, d_model)
pe[:, 0, 0::2] = torch.sin(position * div_term)
pe[:, 0, 1::2] = torch.cos(position * div_term)
self.register_buffer('pe', pe)
def forward(self, x: Tensor) -> Tensor:
"""
Args:
x: Tensor, shape [seq_len, batch_size, embedding_dim]
"""
x = x + self.pe[:x.size(0)]
return self.dropout(x)
このモジュールを使って、埋め込みベクトルに位置情報を追加します。
self.pos_encoder = PositionalEncoding(d_model, dropout)
src = self.pos_encoder(src)
Multi-Head Attention + Feed Forward
PyTorchでは、図のグレーの部分(Multi-Head Attention + Feed Forward)の層がnn.TransformerEncoderLayer
として提供されています。
また、Transformerではグレー部分の層を複数繰り返してつなげます。この繰り返しのためにnn.TransformerEncoder
が用意されています。
encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, dropout)
self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
output = self.transformer_encoder(src, src_mask)
Linear
nn.Linear
モジュールで、結果のベクトルを全語彙に該当するベクトルにマッピングします。
この結果のうち最も値の大きい単語が、入力された文(途中で切れている)の次にくる単語となります。
self.decoder = nn.Linear(d_model, ntoken)
output = self.decoder(output)
以下が、作成したTransformerのモデルです。
__init__
メソッドでネットワーク構造を定義し、forwardメソッドで入力データに対する処理を実装します。
forward
メソッドの戻り値がネットワークの出力となります。
import math
from typing import Tuple
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.utils.data import dataset
class TransformerModel(nn.Module):
def __init__(self, ntoken: int, d_model: int, nhead: int, d_hid: int,
nlayers: int, dropout: float = 0.5):
super().__init__()
self.model_type = 'Transformer'
self.pos_encoder = PositionalEncoding(d_model, dropout)
encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, dropout)
self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
self.encoder = nn.Embedding(ntoken, d_model)
self.d_model = d_model
self.decoder = nn.Linear(d_model, ntoken)
self.init_weights()
def init_weights(self) -> None:
initrange = 0.1
self.encoder.weight.data.uniform_(-initrange, initrange)
self.decoder.bias.data.zero_()
self.decoder.weight.data.uniform_(-initrange, initrange)
def forward(self, src: Tensor, src_mask: Tensor) -> Tensor:
"""
Args:
src: Transformerへの入力データ
src_mask: 入力データにかけるマスク
Returns:
Transformerの出力
"""
src = self.encoder(src) * math.sqrt(self.d_model)
src = self.pos_encoder(src)
output = self.transformer_encoder(src, src_mask)
output = self.decoder(output)
return output
データセットの読み込み
今回はデータセットとして、WikiText2を使います。
WikiText2は、環境構築で既にダウンロード済みです。
テキストのOne-hotベクトル化
data_process
メソッドで、WikiText2内の各単語を1次元のOne-hotベクトルに変換します。
from torchtext.datasets import WikiText2
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
train_iter = WikiText2(split='train')
tokenizer = get_tokenizer('basic_english')
vocab = build_vocab_from_iterator(map(tokenizer, train_iter), specials=['<unk>'])
vocab.set_default_index(vocab['<unk>'])
def data_process(raw_text_iter: dataset.IterableDataset) -> Tensor:
data = [torch.tensor(vocab(tokenizer(item)), dtype=torch.long) for item in raw_text_iter]
return torch.cat(tuple(filter(lambda t: t.numel() > 0, data)))
# train_iterはvocabの構築で消費されるため、ここでもう一度作成する
train_iter, val_iter, test_iter = WikiText2()
train_data = data_process(train_iter)
val_data = data_process(val_iter)
test_data = data_process(test_iter)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
ミニバッチ作成
並列処理を可能にするために、batchify
メソッドを使って、入力データをいくつかのグループ(ミニバッチ)に分割します。グループの数は引数のbsz
で指定します。
例えば、data
として下記のような26文字のアルファベットからなるベクトル、
\left[
\begin{array}{ccccccc}
A & B & C & \ldots & X & Y & Z
\end{array}
\right]
bsz
(バッチサイズ)を4とした場合、出力は下記のようになります。
\begin{eqnarray}
\left[
\begin{array}{cccc}
\left[
\begin{array}{c}
A\\B\\C\\D\\E\\F
\end{array}
\right]&
\left[
\begin{array}{c}
G\\H\\I\\J\\K\\L
\end{array}
\right]&
\left[
\begin{array}{c}
M\\N\\O\\P\\Q\\R
\end{array}
\right]&
\left[
\begin{array}{c}
S\\T\\U\\V\\W\\X
\end{array}
\right]
\end{array}
\right]
\end{eqnarray}
def batchify(data: Tensor, bsz: int) -> Tensor:
"""
Args:
data: 入力データ
bsz: バッチサイズ
Returns:
ミニバッチ
"""
seq_len = data.size(0) // bsz
data = data[:seq_len * bsz]
data = data.view(bsz, seq_len).t().contiguous()
return data.to(device)
batch_size = 20
eval_batch_size = 10
train_data = batchify(train_data, batch_size) # shape [seq_len, batch_size]
val_data = batchify(val_data, eval_batch_size)
test_data = batchify(test_data, eval_batch_size)
学習用データ、正解データの作成
get_batch
メソッドを使って、学習用データ(または検証用データ)と正解データを作成します。
正解データは、学習用データの単語列を、1単語前にずらした単語列となります。
先程のアルファベットのミニバッチを入力とし、i
を0、bptt
を2とした場合、get_batchの出力は以下のようになります(左が学習用データ、右が訓練用データ)。
\begin{eqnarray}
\left(
\begin{array}{ccc}
\left[
\begin{array}{c}
\left[
\begin{array}{cccc}
A & G & M & S
\end{array}
\right]\\
\left[
\begin{array}{cccc}
B & H & N & T
\end{array}
\right]
\end{array}
\right]
& , &
\left[
\begin{array}{c}
\left[
\begin{array}{cccc}
B & H & N & T
\end{array}
\right]\\
\left[
\begin{array}{cccc}
C & I & O & U
\end{array}
\right]
\end{array}
\right]
\end{array}
\right)
\end{eqnarray}
bptt = 35
def get_batch(source: Tensor, i: int) -> Tuple[Tensor, Tensor]:
"""
Args:
source: ミニバッチ
i: int
Returns:
学習用データ、正解データのtuple (data, target)
"""
seq_len = min(bptt, len(source) - 1 - i)
data = source[i:i+seq_len]
target = source[i+1:i+1+seq_len].reshape(-1)
return data, target
マスキング
generate_square_subsequent_mask
を使って、Transformerが学習を行う際に、現在注目している単語に後続する単語を参照できなくします。
def generate_square_subsequent_mask(sz: int) -> Tensor:
return torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)
学習と検証
ここから、Transformerを使った学習と検証を説明します。大まかな流れは、以前作成したニューラルネットワークの時と同じです。
損失関数
Transformerの学習では、出力と正解との誤差(損失関数)を計算し、損失関数の値が小さくなるようにしていきます。
今回は、多クラス分類の学習であるため交差エントロピー誤差(cross entropy error)を損失関数として使います。PyTorchではnn.CrossEntropyLoss
として提供されています。
criterion = nn.CrossEntropyLoss()
最適化
最適化とは、先ほど説明した損失関数の値が小さくなるように、Transformerのパラメーター(重み、バイアス)を調整することです。パラメーターの調整量のことを勾配(gradient)と呼びます。
今回は最適化アルゴリズムの1つである確率的勾配降下法(stochastic gradient descent, SDG)を使用します。PyTorchではtorch.optim.SGD
として提供されています。
torch.optim.SGD
には、モデルのパラメーターと学習係数を指定します。学習係数によりパラメーターの更新量を調整することができます。
lr = 5.0 # 学習係数
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
学習
ミニバッチ単位で以下の手順を実行し、学習を行います。
(1) Transformerに学習用データを入力し、出力を得る。
# data: Transformerへの入力、src_mask: dataにかけるマスク、model: Transformerモデル、output: Transformerからの出力
output = model(data, src_mask)
(2) 出力と正解から損失関数を計算する。
# output: Transformerからの出力、targets: 正解、criterion: 損失関数
loss = criterion(output.view(-1, ntokens), targets)
(3) 勾配の値をリセットする(0にする)。
optimizer.zero_grad()
(4) 損失関数から誤差逆伝播法(back propagation)により、Transformerの全パラメーターの勾配を計算する。勾配計算はPyTorch組み込みの微分エンジンtorch.autograd
により行われています。詳細を知りたい方はこちらをご覧ください。
loss.backward()
(5) 計算した勾配を用いて、全パラメーターの値を更新する
optimizer.step()
以下は、上記の手順をまとめたメソッドです。
import copy
import time
ntokens = len(vocab) # 語彙数
emsize = 200 # 埋め込みベクトルの次元
d_hid = 200 # nn.TransformerEncoderのフィードフォワードネットワークの次元
nlayers = 2 # nn.TransformerEncoder内のnn.TransformerEncoderLayerの数
nhead = 2 # nn.MultiheadAttention内のヘッドの数
dropout = 0.2 # dropoutの割合
model = TransformerModel(ntokens, emsize, nhead, d_hid, nlayers, dropout).to(device)
criterion = nn.CrossEntropyLoss()
lr = 5.0 # learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)
def train(model: nn.Module) -> None:
model.train() # 学習モードに切り替え
total_loss = 0.
log_interval = 200
start_time = time.time()
src_mask = generate_square_subsequent_mask(bptt).to(device)
num_batches = len(train_data) // bptt
for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):
data, targets = get_batch(train_data, i)
batch_size = data.size(0)
if batch_size != bptt: # 最後のバッチの時のみ
src_mask = src_mask[:batch_size, :batch_size]
output = model(data, src_mask)
loss = criterion(output.view(-1, ntokens), targets)
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
optimizer.step()
total_loss += loss.item()
if batch % log_interval == 0 and batch > 0:
lr = scheduler.get_last_lr()[0]
ms_per_batch = (time.time() - start_time) * 1000 / log_interval
cur_loss = total_loss / log_interval
ppl = math.exp(cur_loss)
print(f'| epoch {epoch:3d} | {batch:5d}/{num_batches:5d} batches | '
f'lr {lr:02.2f} | ms/batch {ms_per_batch:5.2f} | '
f'loss {cur_loss:5.2f} | ppl {ppl:8.2f}')
total_loss = 0
start_time = time.time()
検証
検証では、検証データをTransformerに入力し、得られた出力と正解との誤差を計算します。
検証では学習が不要なため、torch.no_grad()
によって勾配計算に必要な処理を無効にします(処理性能向上のため)。
def evaluate(model: nn.Module, eval_data: Tensor) -> float:
model.eval() # 検証モードに切り替え
total_loss = 0.
src_mask = generate_square_subsequent_mask(bptt).to(device)
with torch.no_grad():
for i in range(0, eval_data.size(0) - 1, bptt):
data, targets = get_batch(eval_data, i)
batch_size = data.size(0)
if batch_size != bptt:
src_mask = src_mask[:batch_size, :batch_size]
output = model(data, src_mask)
output_flat = output.view(-1, ntokens)
total_loss += batch_size * criterion(output_flat, targets).item()
return total_loss / (len(eval_data) - 1)
プログラムの実行
最後に、下記のプログラムで3エポックの学習(train)+検証(evaluate)を繰り返します。
best_val_loss = float('inf')
#エポック数
epochs = 3
best_model = None
for epoch in range(1, epochs + 1):
epoch_start_time = time.time()
train(model)
val_loss = evaluate(model, val_data)
val_ppl = math.exp(val_loss)
elapsed = time.time() - epoch_start_time
print('-' * 89)
print(f'| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | '
f'valid loss {val_loss:5.2f} | valid ppl {val_ppl:8.2f}')
print('-' * 89)
if val_loss < best_val_loss:
best_val_loss = val_loss
best_model = copy.deepcopy(model)
scheduler.step()
その際の出力結果がこちらです。学習が進むにつれ正解率(Accuracy)が上昇し、誤差(loss)が小さくなっていることが確認できます。
ちなみに、GPUは使っていないため、3エポックでもかなりの時間がかかりました。実際に試してみる方は、Google Colaboratoryでの実行をおすすめします。
| epoch 1 | 200/ 2928 batches | lr 5.00 | ms/batch 420.44 | loss 8.01 | ppl 3018.63
| epoch 1 | 400/ 2928 batches | lr 5.00 | ms/batch 434.12 | loss 6.86 | ppl 950.57
| epoch 1 | 600/ 2928 batches | lr 5.00 | ms/batch 440.21 | loss 6.42 | ppl 615.38
| epoch 1 | 800/ 2928 batches | lr 5.00 | ms/batch 436.10 | loss 6.29 | ppl 540.35
| epoch 1 | 1000/ 2928 batches | lr 5.00 | ms/batch 440.15 | loss 6.18 | ppl 480.69
| epoch 1 | 1200/ 2928 batches | lr 5.00 | ms/batch 430.71 | loss 6.15 | ppl 468.42
| epoch 1 | 1400/ 2928 batches | lr 5.00 | ms/batch 424.53 | loss 6.11 | ppl 449.43
| epoch 1 | 1600/ 2928 batches | lr 5.00 | ms/batch 426.39 | loss 6.10 | ppl 445.31
| epoch 1 | 1800/ 2928 batches | lr 5.00 | ms/batch 422.44 | loss 6.02 | ppl 411.21
| epoch 1 | 2000/ 2928 batches | lr 5.00 | ms/batch 412.33 | loss 6.01 | ppl 407.06
| epoch 1 | 2200/ 2928 batches | lr 5.00 | ms/batch 417.48 | loss 5.89 | ppl 359.91
| epoch 1 | 2400/ 2928 batches | lr 5.00 | ms/batch 430.99 | loss 5.96 | ppl 389.16
| epoch 1 | 2600/ 2928 batches | lr 5.00 | ms/batch 440.37 | loss 5.95 | ppl 383.34
| epoch 1 | 2800/ 2928 batches | lr 5.00 | ms/batch 433.94 | loss 5.88 | ppl 357.36
-----------------------------------------------------------------------------------------
| end of epoch 1 | time: 1303.18s | valid loss 5.81 | valid ppl 334.50
-----------------------------------------------------------------------------------------
| epoch 2 | 200/ 2928 batches | lr 4.75 | ms/batch 448.69 | loss 5.85 | ppl 347.13
| epoch 2 | 400/ 2928 batches | lr 4.75 | ms/batch 451.74 | loss 5.84 | ppl 344.30
| epoch 2 | 600/ 2928 batches | lr 4.75 | ms/batch 442.30 | loss 5.65 | ppl 285.24
| epoch 2 | 800/ 2928 batches | lr 4.75 | ms/batch 450.91 | loss 5.69 | ppl 296.31
| epoch 2 | 1000/ 2928 batches | lr 4.75 | ms/batch 461.74 | loss 5.64 | ppl 282.26
| epoch 2 | 1200/ 2928 batches | lr 4.75 | ms/batch 464.20 | loss 5.67 | ppl 290.53
| epoch 2 | 1400/ 2928 batches | lr 4.75 | ms/batch 433.12 | loss 5.68 | ppl 293.03
| epoch 2 | 1600/ 2928 batches | lr 4.75 | ms/batch 433.08 | loss 5.70 | ppl 299.90
| epoch 2 | 1800/ 2928 batches | lr 4.75 | ms/batch 446.36 | loss 5.64 | ppl 281.13
| epoch 2 | 2000/ 2928 batches | lr 4.75 | ms/batch 443.56 | loss 5.66 | ppl 286.80
| epoch 2 | 2200/ 2928 batches | lr 4.75 | ms/batch 445.23 | loss 5.54 | ppl 255.44
| epoch 2 | 2400/ 2928 batches | lr 4.75 | ms/batch 454.44 | loss 5.63 | ppl 279.42
| epoch 2 | 2600/ 2928 batches | lr 4.75 | ms/batch 451.50 | loss 5.63 | ppl 279.43
| epoch 2 | 2800/ 2928 batches | lr 4.75 | ms/batch 456.88 | loss 5.57 | ppl 262.01
-----------------------------------------------------------------------------------------
| end of epoch 2 | time: 1361.26s | valid loss 5.65 | valid ppl 284.21
-----------------------------------------------------------------------------------------
| epoch 3 | 200/ 2928 batches | lr 4.51 | ms/batch 443.37 | loss 5.59 | ppl 268.65
| epoch 3 | 400/ 2928 batches | lr 4.51 | ms/batch 476.49 | loss 5.61 | ppl 272.55
| epoch 3 | 600/ 2928 batches | lr 4.51 | ms/batch 466.81 | loss 5.41 | ppl 222.53
| epoch 3 | 800/ 2928 batches | lr 4.51 | ms/batch 447.77 | loss 5.47 | ppl 236.85
| epoch 3 | 1000/ 2928 batches | lr 4.51 | ms/batch 447.26 | loss 5.42 | ppl 225.61
| epoch 3 | 1200/ 2928 batches | lr 4.51 | ms/batch 444.40 | loss 5.46 | ppl 235.93
| epoch 3 | 1400/ 2928 batches | lr 4.51 | ms/batch 461.98 | loss 5.49 | ppl 241.44
| epoch 3 | 1600/ 2928 batches | lr 4.51 | ms/batch 433.08 | loss 5.52 | ppl 249.04
| epoch 3 | 1800/ 2928 batches | lr 4.51 | ms/batch 426.84 | loss 5.46 | ppl 234.91
| epoch 3 | 2000/ 2928 batches | lr 4.51 | ms/batch 426.60 | loss 5.48 | ppl 238.70
| epoch 3 | 2200/ 2928 batches | lr 4.51 | ms/batch 5334.73 | loss 5.35 | ppl 210.36
| epoch 3 | 2400/ 2928 batches | lr 4.51 | ms/batch 450.23 | loss 5.46 | ppl 234.02
| epoch 3 | 2600/ 2928 batches | lr 4.51 | ms/batch 453.25 | loss 5.47 | ppl 238.04
| epoch 3 | 2800/ 2928 batches | lr 4.51 | ms/batch 468.60 | loss 5.39 | ppl 219.83
-----------------------------------------------------------------------------------------
| end of epoch 3 | time: 2341.51s | valid loss 5.60 | valid ppl 269.74
-----------------------------------------------------------------------------------------
まとめ
前回のPyTorchを使ったニューラルネットワークの作成に続き、今回はTransformerモデルの作成をしましたが、入力データ、モデルの作成以外は、ほぼ同じ実装なことに気づきました。
PyTorchでのCNNやRNNは未経験なので同じパターンかどうかわかりませんが、このパターンは覚えておいて損はなさそうです。
次回からは、これまで学習してきた内容をもとに、PyTorchを使ってBERTの実装をしてみようと思います。
今回作成したプログラムはGitHubで公開しています。
https://github.com/age884/qiita_transformer/blob/main/transformer.py