6
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

機械翻訳(Transformer)

Last updated at Posted at 2022-01-07

Transformer を用いて機械翻訳を実装します!

###必要なモジュールを読み込み

!pip install -q pytorch_lightning
!pip install -q torchtext==0.11.0
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
import torchtext

###分かち書き(spaCy)
日本語・英語の分かち書きには、spaCy というライブラリを使います。

!apt install aptitude swig
!aptitude install mecab libmecab-dev mecab-ipadic-utf8 git make curl xz-utils file -y
!pip install mecab-python3==0.996.5
!pip install unidic-lite
!pip install -q fugashi

spaCy を用いて、分かち書きの準備をします。
spacy.blank() に en や ja などの言語コードを指定すると、その言語用のトークナイザを備えた空のパイプラインを作成できます(学習済みモデルは含まれません)。

import spacy

# Blank spaCy pipelines for Japanese and English; only their tokenizers are
# used (see tokenize_ja / tokenize_en below).
JA, EN = spacy.blank('ja'), spacy.blank('en')

日本語と英語をそれぞれ分かち書きする関数を定義します。

def _token_texts(nlp, sentence):
    """Return the surface strings of the tokens that *nlp*'s tokenizer yields for *sentence*."""
    return [token.text for token in nlp.tokenizer(sentence)]


def tokenize_ja(sentence):
    """Split a Japanese sentence into a list of token strings using the ``JA`` pipeline."""
    return _token_texts(JA, sentence)


def tokenize_en(sentence):
    """Split an English sentence into a list of token strings using the ``EN`` pipeline."""
    return _token_texts(EN, sentence)

挙動を確認します。

# Quick check of the Japanese tokenizer on a short phrase.
tokenize_ja('月の宴')

['月', 'の', '宴']

# Quick check of the English tokenizer; punctuation becomes separate tokens.
tokenize_en('Tsuki no Utage (party of the moon)')

['Tsuki', 'no', 'Utage', '(', 'party', 'of', 'the', 'moon', ')']

###辞書の作成

yield_tokens という名前で、データフレームの列から文を1つずつ取り出し、トークン化した結果(トークンのリスト)を順に返すジェネレータ関数を定義します。

def yield_tokens(df, tokenize):
    """Lazily yield ``tokenize(line)`` for every entry of *df*.

    Args:
        df: iterable of sentences (e.g. a single DataFrame column).
        tokenize: callable that turns one sentence into a list of tokens.
    """
    yield from map(tokenize, df)

辞書を作成します。

from torchtext.vocab import build_vocab_from_iterator

# Special tokens are inserted first (special_first=True) in the listed order,
# so <unk>=0, <pad>=1, <bos>=2, <eos>=3 in both vocabularies.  translate_index
# pads with index 1, which relies on this ordering.
vocab_ja = build_vocab_from_iterator(
    yield_tokens(df_train['Japanese'], tokenize_ja),
    specials=('<unk>', '<pad>', '<bos>', '<eos>'),
    special_first=True)

vocab_en = build_vocab_from_iterator(
    yield_tokens(df_train['English'], tokenize_en),
    specials=('<unk>', '<pad>', '<bos>', '<eos>'),
    special_first=True)

# The vocabularies are built from the training split only, so validation/test
# sentences can contain unseen tokens.  Without a default index torchtext's
# Vocab raises on such lookups; map them to <unk> instead.
vocab_ja.set_default_index(vocab_ja['<unk>'])
vocab_en.set_default_index(vocab_en['<unk>'])

辞書の長さを確認します。

# Report the vocabulary size of the Japanese (source) and English (target) sides.
for lang_vocab in (vocab_ja, vocab_en):
    print(len(lang_vocab))

22801
30664

文字列をインデックス列へ置き換える関数を定義します。英語側(出力側)には文頭・文末を表す <bos>・<eos> を付けます。

def transform_ja(x):
    """Map a Japanese sentence to a list of ``vocab_ja`` indices (no <bos>/<eos>)."""
    return vocab_ja(tokenize_ja(x))


def transform_en(x):
    """Map an English sentence to ``vocab_en`` indices wrapped as <bos> ... <eos>."""
    # The specials were registered in lower case ('<bos>', '<eos>') when the
    # vocabulary was built; an upper-case '<BOS>' / '<EOS>' would not match them.
    return [vocab_en['<bos>']] + vocab_en(tokenize_en(x)) + [vocab_en['<eos>']]

置き換えの準備ができたので、パディングを含めてインデックスへの置き換えを行います。

from torch.nn.utils.rnn import pad_sequence

def translate_index(df, transform, padding_value=1):
    """Convert each sentence to token indices and right-pad them into one tensor.

    Args:
        df: iterable of sentences (e.g. a DataFrame column).
        transform: callable mapping one sentence to a list of int token indices.
        padding_value: index used to pad shorter sequences.  Defaults to 1, the
            ``<pad>`` index when the vocabulary is built with
            ``specials=('<unk>', '<pad>', ...)`` and ``special_first=True``.

    Returns:
        int64 tensor of shape ``(number_of_sentences, longest_sequence)``.
    """
    text_list = [torch.tensor(transform(text), dtype=torch.int64) for text in df]
    return pad_sequence(text_list, batch_first=True, padding_value=padding_value)
# Index + pad the source (Japanese) and target (English) columns of both splits.
ja_train_tensor, ja_val_tensor = (
    translate_index(frame['Japanese'], transform_ja) for frame in (df_train, df_test)
)
en_train_tensor, en_val_tensor = (
    translate_index(frame['English'], transform_en) for frame in (df_train, df_test)
)

# Show (number_of_sentences, padded_length) for every split.
for split_tensor in (ja_train_tensor, ja_val_tensor, en_train_tensor, en_val_tensor):
    print(split_tensor.shape)

torch.Size([36379, 31])
torch.Size([15592, 28])
torch.Size([36379, 68])
torch.Size([15592, 73])

###DataLoaderの作成

from torch.utils.data import TensorDataset, DataLoader

train_dataset = TensorDataset(ja_train_tensor, en_train_tensor)
val_dataset = TensorDataset(ja_val_tensor, en_val_tensor)

# 60% of the held-out pairs go to validation, the remainder to the test set.
n_val = int(len(val_dataset) * 0.6)
n_test = len(val_dataset) - n_val

# Seed every RNG before the random split so that it is reproducible.
pl.seed_everything(0)
val, test = torch.utils.data.random_split(val_dataset, [n_val, n_test])

# Only the training loader shuffles and drops the last incomplete batch.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, drop_last=True)
val_loader, test_loader = (DataLoader(subset, batch_size=32) for subset in (val, test))

#Encoderの各層を定義
###EncoderのEmbedding層
ポイントは、Embedding 層の第1引数(num_embeddings)を入力側の語彙数(ボキャブラリ数)と等しくすることです。

src_vocab_length = len(vocab_ja)
d_model = 512

# Draw one sample batch to inspect layer output shapes below.  `src` is the
# Japanese (source) batch; `trg_input` drops the last target position so that
# the decoder input lines up with the loss target trg[:, 1:].
src, trg = next(iter(train_loader))
trg_input = trg[:, :-1]

# Source embedding table: one d_model-sized vector per Japanese vocabulary entry.
src_embedder = nn.Embedding(src_vocab_length, d_model)
embeded = src_embedder(src)
embeded.shape

torch.Size([32, 31, 512])

###Positional Encoder層

class PositionalEncoder(nn.Module):
    """Scale token embeddings and add fixed sinusoidal position encodings.

    Follows "Attention Is All You Need": for every even dimension ``i``,
    ``PE[pos, i] = sin(pos / 10000 ** (i / d_model))`` and
    ``PE[pos, i + 1] = cos(pos / 10000 ** (i / d_model))`` (same frequency).

    Args:
        d_model: embedding size of the incoming tensor.
        max_seq_len: longest sequence length the precomputed table supports.
        batch_first: if True (default, matching the ``batch_first=True``
            Transformer layers used in this article) inputs are
            ``(batch, seq, d_model)``; if False they are ``(seq, batch, d_model)``.
    """

    def __init__(self, d_model=512, max_seq_len=5000, batch_first=True):
        super().__init__()
        self.dropout = nn.Dropout(p=0.1)
        self.d_model = d_model
        self.batch_first = batch_first

        position = torch.arange(max_seq_len, dtype=torch.float32).unsqueeze(1)  # (max_seq_len, 1)
        even_dims = torch.arange(0, d_model, 2, dtype=torch.float32)            # 0, 2, 4, ...
        angles = position / torch.pow(10000.0, even_dims / d_model)             # (max_seq_len, ceil(d_model / 2))

        pe = torch.zeros(max_seq_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even columns.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Add a broadcast axis for the batch dimension so one encoding per
        # position is shared by every sequence in the batch.
        pe = pe.unsqueeze(0) if batch_first else pe.unsqueeze(1)

        # Buffer: moves with .to(device) and is saved in state_dict, but is not trained.
        self.register_buffer('pe', pe)

    def forward(self, x):
        # Scale the embeddings up by sqrt(d_model) before adding the encodings.
        x = x * self.d_model ** 0.5
        if self.batch_first:
            x = x + self.pe[:, : x.size(1)]
        else:
            x = x + self.pe[: x.size(0)]
        return self.dropout(x)

#TransformerのEncoder層

# One encoder block (self-attention + feed-forward).  batch_first=True means
# inputs are laid out as (batch, seq, d_model).
encoder_layer = nn.TransformerEncoderLayer(
    d_model=d_model,
    nhead=8,
    dim_feedforward=2048,
    dropout=0.1,
    activation='relu',
    batch_first=True,
)
encoder_layer

# LayerNorm passed as `norm` to the stacked encoder below.
encoder_norm = nn.LayerNorm(d_model)

nn.TransformerEncoderLayer() を 6 回繰り返す形に定義します。

# Six stacked copies of encoder_layer, with encoder_norm applied to the output.
encoder = nn.TransformerEncoder(
    encoder_layer,
    num_layers=6,
    norm=encoder_norm,
)
encoder

#Decoderの各層を定義
###DecoderのEmbedding層

trg_vocab_length = len(vocab_en)

# Target-side (English) embedding table: one d_model-sized vector per vocabulary entry.
trg_embedder = nn.Embedding(num_embeddings=trg_vocab_length, embedding_dim=d_model)

# NOTE(review): assumes `trg_input` (the decoder input batch, i.e. target
# indices without their last position, trg[:, :-1]) already exists.
embeded = trg_embedder(trg_input)
embeded.shape

###Positional Encoder層

# Scale the target embeddings and add positional information.
pos_encoder = PositionalEncoder(d_model=d_model)
pos_embeded = pos_encoder(embeded)

#TransformerのDecoder層

# tgt_key_padding_mask: index of <pad> in the English (target) vocabulary,
# used to mark padded decoder-input positions.
trg_pad_idx = vocab_en['<pad>']

def create_trg_pad_mask(trg, pad_idx=None):
    """Return a bool mask that is True wherever *trg* holds the padding index.

    Args:
        trg: tensor of target token indices.
        pad_idx: padding index to match.  Defaults to the module-level
            ``trg_pad_idx`` (the English vocabulary's ``<pad>`` index).
    """
    if pad_idx is None:
        pad_idx = trg_pad_idx
    return trg == pad_idx
# Padding mask for the sample decoder input: True where trg_input holds <pad>.
trg_pad_mask = create_trg_pad_mask(trg_input)
trg_pad_mask.shape

torch.Size([32, 67])

#Decoderの出力層

# Final projection from d_model features to one logit per English vocabulary entry.
out = nn.Linear(d_model, trg_vocab_length)

# NOTE(review): dec_out (the decoder output) is not defined in this excerpt;
# judging by the printed logit shape it is (batch, trg_len - 1, d_model) — confirm.
logit = out(dec_out)
logit.shape

torch.Size([32, 67, 30664])
出力の値から argmax をかければ、予測値がわかります。

# Turn logits into probabilities over the vocabulary, then take the most likely
# token id at every position of the first sentence in the batch.
y_softmax = logit.softmax(dim=-1)
pred = y_softmax.argmax(dim=-1)[0]

どのような文字列を予測したのか見たければ .lookup_token() を使用します。

# lookup_token takes a Python int index, so unwrap the 0-d tensor first.
print(vocab_en.lookup_token(int(pred[0])))

#損失の計算
各位置で次の単語を語彙の中から当てる分類問題なので、損失には交差エントロピー (CrossEntropy) を使用します。

# Drop the leading <bos> from the target: the prediction at position t is
# scored against token t + 1 of trg.
targets = trg[:, 1:].reshape(-1)
y = logit.reshape(-1, logit.size(-1))

# Positions whose target is <pad> are excluded from the loss.
loss = F.cross_entropy(y, targets, ignore_index=trg_pad_idx)

#Transformerのネットワークを定義

class Transformer(pl.LightningModule):
    """Sequence-to-sequence Transformer trained with PyTorch Lightning.

    ``Encoder`` and ``Decoder`` are not defined in this excerpt.  Batches are
    ``(src, trg)`` pairs of token-index tensors; the decoder is fed
    ``trg[:, :-1]`` and its logits are scored against ``trg[:, 1:]`` with a
    cross-entropy loss that ignores ``trg_pad_idx``.

    Args:
        src_vocab_length, trg_vocab_length: source / target vocabulary sizes.
        d_model, nhead, num_encoder_layers, num_decoder_layers,
        dim_feedforward, dropout, activation: forwarded to Encoder / Decoder.
        src_pad_idx, trg_pad_idx: padding indices used for masks and the loss.
        xavier_init: if True, re-initialize every weight matrix with
            Xavier-uniform values after construction (off by default).
    """

    def __init__(self, src_vocab_length=10000, trg_vocab_length=10000, d_model=512, nhead=8,
                 num_encoder_layers=6, num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
                 activation="relu", src_pad_idx=1, trg_pad_idx=1, xavier_init=False):
        super().__init__()

        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx

        self.encoder = Encoder(src_vocab_length, d_model, nhead, dim_feedforward, num_encoder_layers, dropout, activation)
        self.decoder = Decoder(trg_vocab_length, d_model, nhead, dim_feedforward, num_decoder_layers, dropout, activation)

        # Projects decoder features to one logit per target vocabulary entry.
        self.out = nn.Linear(d_model, trg_vocab_length)

        if xavier_init:
            self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize every parameter with more than one dimension using Xavier-uniform."""
        for param in self.parameters():
            if param.dim() > 1:
                nn.init.xavier_uniform_(param)

    def create_pad_mask(self, input_word, pad_idx):
        """Return a bool mask that is True where *input_word* equals *pad_idx*."""
        return input_word == pad_idx

    def generate_square_subsequent_mask(self, size):
        """Return a ``(size, size)`` float causal mask: 0 on/below the diagonal, -inf above.

        Built on the module's own device (the original referenced an undefined
        global ``device``).
        """
        mask = torch.full((size, size), float('-inf'), device=self.device)
        return torch.triu(mask, diagonal=1)

    def forward(self, src, trg):
        # Teacher forcing: the decoder sees the target without its last token.
        trg_input = trg[:, :-1]

        src_pad_mask = self.create_pad_mask(src, self.src_pad_idx)
        trg_pad_mask = self.create_pad_mask(trg_input, self.trg_pad_idx)
        trg_mask = self.generate_square_subsequent_mask(trg_input.size(1))

        memory = self.encoder(src, src_pad_mask)
        # NOTE(review): src_pad_mask is not given to the decoder, so its
        # cross-attention may attend to source <pad> positions — confirm whether
        # Decoder accepts a memory_key_padding_mask.
        output = self.decoder(memory, trg_input, trg_mask, trg_pad_mask)

        return self.out(output)

    def _compute_loss(self, batch):
        """Return the token-level cross-entropy for a ``(src, trg)`` batch, ignoring <pad>."""
        src, trg = batch
        logit = self(src, trg)

        # The prediction at position t is scored against token t + 1 of trg.
        targets = trg[:, 1:].reshape(-1)
        y = logit.reshape(-1, logit.size(-1))

        # ignore_index: positions whose target is <pad> do not contribute.
        return F.cross_entropy(y, targets, ignore_index=self.trg_pad_idx)

    def training_step(self, batch, batch_idx):
        loss = self._compute_loss(batch)
        self.log('train_loss', loss, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._compute_loss(batch)
        self.log('val_loss', loss, on_step=False, on_epoch=True)
        return loss

    def test_step(self, batch, batch_idx):
        loss = self._compute_loss(batch)
        self.log('test_loss', loss, on_step=False, on_epoch=True)
        return loss

    def configure_optimizers(self):
        # Adam settings from "Attention Is All You Need" (beta2=0.98, eps=1e-9), fixed lr.
        optimizer = torch.optim.Adam(self.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)
        return optimizer

#学習の実行

# Fix every random seed so weight initialization and batch shuffling are reproducible.
pl.seed_everything(0)

src_vocab_length = len(vocab_ja)
trg_vocab_length = len(vocab_en)
src_pad_idx = vocab_ja['<pad>']
trg_pad_idx = vocab_en['<pad>']

# Instantiate the model.
net = Transformer(
    src_vocab_length=src_vocab_length,
    trg_vocab_length=trg_vocab_length,
    src_pad_idx=src_pad_idx,
    trg_pad_idx=trg_pad_idx,
)

# NOTE(review): `gpus=` is deprecated in newer PyTorch Lightning releases in
# favour of `accelerator="gpu", devices=1`; switch if the installed version rejects it.
trainer = pl.Trainer(max_epochs=5, gpus=1)
trainer.fit(net, train_loader, val_loader)
# Trainer.test receives its loaders via `dataloaders`; the old
# `test_dataloaders` keyword was deprecated and later removed.
result = trainer.test(dataloaders=test_loader)

DATALOADER:0 TEST RESULTS
{'test_loss': 5.714888095855713}

6
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?