9
11

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

日本語Pre-Trained BERTを用いた対話Botの作成

Last updated at Posted at 2022-08-06

この頃機会があったので、私が過去の研究で作成した、日本語BERTを使ったAttention付きSeq2Seqモデルのソースコード群をまとめてみました。
同じようなことをされたい方の一助になればと思います。

作成したソースコード、モデルパラメータ、N-gram言語モデルはGitHubにて公開しています。
リンクは記事のまとめに記載しています。

目次

  1. 概要
  2. 環境構築・事前準備
    2.1. 日本語BERTの準備
    2.2. 各ライブラリのインストール
     2.2.1. Juman++ のインストール
     2.2.2. KenLM のインストール
    2.3. N-gram言語モデルの構築
    2.4. 学習データの作成
     2.4.1. 学習データからの未知語除去
  3. モデル構造
  4. ソースコード
    4.1. Encoder
    4.2. Attenstion
    4.3. Decoder
    4.4. Model synthesize
    4.5. Dataloader
    4.6. Scorer
    4.7. Logger Generator
    4.8. Trainer (main.py)
  5. 学習結果
  6. 実際の対話
    6.1. 対話
  7. まとめ
  8. 参考文献

1. 概要

2020年当時、高専時代の卒業研究でBERTを使った対話システムの作成に取り組んでおり、その練習として、BERTをAttention付きSeq2Seqのエンコーダに用いた対話システムの作成に取り組んでいました。
基本的な構造は、2層LSTMのAttention付きEncoder-DecoderモデルのEncoderにBERTを追加したものになります。DecoderへのLSTM状態の入力はLSTMが担い、AttentionのKey-Valueへの入力はBERTが行います。

2. 環境構築・事前準備

実行環境

  • OS:Ubuntu 20.04
  • Python 3.7.13 (conda 4.12.0)
  • CUDA 11.1
  • cuDNN 8.4.1

Python ライブラリ

  • PyTorch 1.8.2 +CUDA 11.1
  • protbuf 3.19.4
  • transformers 3.4.0
  • pyknp 0.6.1
  • kenlm 0.0.0

ディレクトリ構造

dialog_sys
├── data
│   ├── eval
│   └── train
├── log
├── ngram
│   ├──bert
│   └── scoring
│       └── models
├── nn
├── params
├── resource
│   └── bert
└── utils

 
使用したコーパス

  • Twitterのツイートとリプライのペア
  • 落語コーパス
  • 名大コーパス
  • 日本語Wikipedia(N-gramの構築のみ)

2.1. 日本語BERTの準備

今回使用するBERTは、京都大学の黒橋・褚・河原研究室で公開されている学習済みモデル(BASE WWM版)を利用します。
BERT日本語Pretrainedモデル - 黒橋・褚・村脇研究室

ダウンロードして解凍したBERTモデルは resource/bert に移します。
また、解凍したディレクトリの中にはvocab.txtがあるので、 ngram/bert にコピーしておきます。
これでBERTのインストールは完了です。

2.2. 各ライブラリのインストール

京大BERTを利用するにあたり、形態素解析器としてJuman++をインストールします。
必要Pythonライブラリで示した pyknp は Python コード内でJuman++を使用するためのライブラリです。

さらに、対話の応答生成では、応答の自然さ向上のため、N-gram言語モデルによる応答文のスコアリングを行います。
そのため、KenLMをインストールし、N-gram言語モデルの構築も行います。

次節より、各ライブラリのインストール作業を始めます。
各ライブラリのインストールにあたり、cmakeが入ってない方はcmakeのインストールを行ってください。

$ sudo apt-get install build-essential libboost-all-dev cmake zlib1g-dev libbz2-dev liblzma-dev
$ sudo apt install cmake

2.2.1. Juman++ のインストール

$ cd ~/
$ wget https://github.com/ku-nlp/jumanpp/releases/download/v2.0.0-rc2/jumanpp-2.0.0-rc2.tar.xz
$ tar xfv jumanpp-2.0.0-rc2.tar.xz  
$ cd jumanpp-2.0.0-rc2
$ mkdir bld
$ cd bld
$ sudo cmake .. -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=/usr/local
$ sudo make install -j2

動作確認

$ echo "外国人参政権" | jumanpp
外国 がいこく 外国 名詞 6 普通名詞 1 * 0 * 0 "代表表記:外国/がいこく ドメイン:政治 カテゴリ:場所-その他"
人 じん 人 名詞 6 普通名詞 1 * 0 * 0 "代表表記:人/じん カテゴリ:人 漢字読み:音"
@ 人 じん 人 名詞 6 普通名詞 1 * 0 * 0 "代表表記:人/ひと カテゴリ:人 漢字読み:訓"
参政 さんせい 参政 名詞 6 サ変名詞 2 * 0 * 0 "代表表記:参政/さんせい ドメイン:政治 カテゴリ:抽象物"
権 けん 権 名詞 6 普通名詞 1 * 0 * 0 "代表表記:権/けん カテゴリ:抽象物 漢字読み:音"
EOS

2.2.2. KenLM のインストール

$ cd ~/
$ wget -O - https://kheafield.com/code/kenlm.tar.gz |tar xz
$ mkdir kenlm/build
$ cd kenlm/build
$ sudo cmake ..
$ sudo make -j2

$ cd ~/
$ chmod u+x ./kenlm/bin/*

(ディレクトリ構造で示した ngram ディレクトリ下に build ディレクトリをコピー)
$ cp -r kenlm/build /path/to/dialog_sys/ngram/
$ mv build kenlm

実際にpythonライブラリとしてkenlmを使用する際には、Python仮想環境(Anacondaなど)を立ち上げた上で kenlm ディレクトリ下にある setup.py を実行する必要があります。

$ python setup.py install

これで必要なライブラリのインストールは完了です。

2.3. N-gram言語モデルの構築

先ほどインストールしたKenLMを用いて応答文スコアリング用のN-gram言語モデルを作成します。
事前にコーパスを用意したうえで読み進めてください。

ここで使用するコーパスデータは1行に1文の形式を想定しています。
(対話システムモデルの学習に使うデータは、発話と応答のペア形式で別形式なので、注意してください)

まず、用意したデータに前処理を行います。
ここでは、データのファイル名を inputs.txt としています。

preprocess.py
import re
from pyknp import Juman
from transformers import BertTokenizer
import torch

jumanpp = Juman()

# GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Device : ", device)

# parameters
bert_seq_len = 512

class JumanTokenizer():
    def __init__(self):
        self.juman = Juman()

    def tokenize(self, text):
        result = self.juman.analysis(text)
        return [mrph.midasi for mrph in result.mrph_list()]

class BertWithJumanModel(BertTokenizer):
    def __init__(self, vocab_file_name="vocab.txt"):
        super().__init__(
            vocab_file_name,
            do_lower_case=False,
            do_basic_tokenize=False
        )
        self.juman_tokenizer = JumanTokenizer()

    def _preprocess_text(self, text):
        return text.replace(" ", "")  # for Juman

    def __call__(self, text):
        preprocessed_text = self._preprocess_text(text)
        tokens = self.juman_tokenizer.tokenize(preprocessed_text)
        bert_tokens = self.tokenize(" ".join(tokens))

        return bert_tokens

bert_tokenizer = BertWithJumanModel('bert/vocab.txt')

with open('inputs.txt', 'r') as fi, open('inputs_tknz.txt', 'w') as fo:
    for line in fi:
        for sentence in re.split('[,.、。,.]', line):
            sentence = re.sub('\n', '', sentence)
            fo.write(' '.join(bert_tokenizer(sentence)) + '\n')

preprocess.py では用意したコーパスをさらに読点で分割し、 Juman++ & BertTokenizer で形態素解析して出力します。
読点で分割するのは、対話システムが生成した応答を評価する際、読点の有無で評価結果(出現確率)が揺れることを防ぐためです。

例えば、3-gram 言語モデルは以下の式で表されますが、

Trigram LM = p(w_1^{N}) = \prod_{n=1}^{N} p(w_n|w_{n-2}, w_{n-1})

この $ w_{n-1} $ が読点である場合、ない場合と比較して評価結果が悪化する場合があります。
読点は語句に比べ、省略と挿入に関して文章の書き手にある程度の自由が存在し、出現パターンが一定ではないことが原因ではないかと考えられます。

そこで単純な解決策として、読点で文を分割し、ある程度の語彙のまとまりを一つの文として扱う方法を採用しています。

preprocess.py の実行後、同じディレクトリ内に inputs_tknz.txt が作成されているはずなので、この inputs_tknz.txt を使用しN-gram言語モデルの構築を行います。

$ cat inputs_tknz.txt |./kenlm/bin/lmplz -o 3 > ngram.arpa

lmplz のオプション -o はモデルの次数(order)を指定するものです。
今回は 3-gram なので -o 3 としています。

出来上がった ngram.arpa は、高速化のためバイナリ形式に変換します。

$ ./kenlm/bin/build_binary trie ngram.arpa ngram.binary

出力された ngram.binary は dialog_sys/ngram/scoring/models ディレクトリ下に配置します。

これでN-gram言語モデルの構築は完了です。

2.4. 学習データの作成

学習データは dialog_sys/data ディレクトリ下の train に訓練データを、eval に検証データを配置します。

今回、学習で使用するデータの形式は以下の通りになっています。

  • ヘッダーにファイルが持つデータのレコード数(行数)
  • 各レコードは「発話文 , 応答文」の形式。
  • 発話・応答文はそれぞれ、ID化された形態素がスペース区切りで並ぶ。

以下に例を示します。(データは公開できないため数字は適当です)

data0.txt
3
2 52 40 56 789 321 7 3,2 78 54 2 67 7 3
2 76 7324 567 3,2 5532 44 901 7 3
2 690 1702 7 3,2 590 112 68 7 3

DataLoaderで読み込む際には、以下の流れになります。

  1. 先頭を見てレコード数を把握
  2. 目的のレコードを読み出す
  3. 「,」で分割(発話文、応答文の分割)
  4. スペースで分割(形態素ごとに分割)

また、精度が良い代わりに動作が重い Jumann++ を形態素解析器として利用するため、学習の高速化の観点から、学習データの形態素解析とID化は事前に済ませておきます。

2.4.1. 学習データからの未知語除去

(こちらは、興味のある方のみ取り組んでみてください。)

形態素解析結果には未知語が含まれ、BPEを適用した場合でも5%前後の確率で未知語が発生します。
学習データに未知語が混入すると、当然生成される応答文にも未知語が出現するリスクがあります。
そこで、transformers の BertForMaskedLM を用いて未知語である [UNK]トークンを [MASK]トークンに置き換えて単語予測を行います。こうすることで、学習データから未知語を削除します。
こちらは先ほど述べた形式のデータに次のプログラムを適用することで達成できます。

dialog_sys/utils ディレクトリ下に配置してください。

replace_unk.py
import torch
from torch.nn.utils.rnn import pack_sequence, pack_padded_sequence, pad_packed_sequence, pad_sequence
from transformers import BertTokenizer, BertForMaskedLM, BertConfig
import numpy as np

import os

config = BertConfig.from_json_file('resource/bert/Japanese_L-12_H-768_A-12_E-30_BPE_WWM/bert_config.json')
model = BertForMaskedLM.from_pretrained('resource/bert/Japanese_L-12_H-768_A-12_E-30_BPE_WWM/pytorch_model.bin', config=config)
bert_tokenizer = BertTokenizer('resource/bert/Japanese_L-12_H-768_A-12_E-30_BPE_WWM/vocab.txt', do_lower_case=False, do_basic_tokenize=False)
model.eval()

# GPUのセット
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # 三項演算子
print("使用デバイス:", device)
model.to(device)

output_train_path = "data/train_cleaned"
output_eval_path = "data/eval_cleaned"
train_path = "data/train"
eval_path = "data/eval"

unk_id = bert_tokenizer.convert_tokens_to_ids('[UNK]')
msk_id = bert_tokenizer.convert_tokens_to_ids('[MASK]')
pad_id = bert_tokenizer.convert_tokens_to_ids('[PAD]')
sep_id = bert_tokenizer.convert_tokens_to_ids('[SEP]')
cls_id = bert_tokenizer.convert_tokens_to_ids('[CLS]')

def data_recreate(in_path, out_path, head_idx=0, batch_size=8, word_replace=False):
    """[UNK]を[MASK]に置換し, BERTのMaskedLanguageModelにて穴埋めする.

    Args:
        in_path (_type_): 入力データのディレクトリ.
        out_path (_type_): 出力データのディレクトリ.
        head_idx (int, optional): 入力データの読み込み開始位置. Defaults to 0.
        batch_size (int, optional): BERTに入力する際のバッチサイズ. Defaults to 8.
        word_replace (bool, opyional): [MASK]以外の単語の置き換えを行うか指定する. Defaults to False.
    """
    files = os.listdir(in_path)
    file_num = len(files)

    file_exists = os.path.exists(out_path)
    if file_exists == False:
        #保存先のディレクトリの作成
        os.mkdir(out_path)

    for i, file_name in enumerate(files):
        path = os.path.join(in_path, file_name)
        print('processing for {} ... '.format(path), end='')
        
        with open(path, mode='r') as f:
            lines = f.read().splitlines()
            
        headders = lines[:head_idx]
        
        querys = []
        answers = []
        for line in lines[head_idx:]:
            # convert str -> int
            query_str, answer_str = line.split(',')
            
            query_str = query_str.split()
            answer_str = answer_str.split()
            
            query = [int(q) for q in query_str]
            answer = [int(a) for a in answer_str]
            
            querys.append(query)
            answers.append(answer)
            
        # shape data
        packs = pack_sequence([torch.tensor(t, device=device).detach() for t in answers], enforce_sorted=False)
        (answers, answers_len) = pad_packed_sequence(
            packs, 
            batch_first=True, 
            padding_value=0.0
        )
        packs = pack_sequence([torch.tensor(t, device=device).detach() for t in querys], enforce_sorted=False)
        (querys, querys_len) = pad_packed_sequence(
            packs, 
            batch_first=True, 
            padding_value=0.0
        )
        
        # setting device
        querys.to(device)
        answers.to(device)
        
        if word_replace:
            # record [PAD] position
            querys_pad = (querys == pad_id)
            answers_pad = (answers == pad_id)
            # record [SEP] position
            querys_sep = (querys == sep_id)
            answers_sep = (answers == sep_id)
            # record [CLS] position
            querys_cls = (querys == cls_id)
            answers_cls = (answers == cls_id)
        
        # record [UNK] position
        querys_unk = (querys == unk_id)
        answers_unk = (answers == unk_id)
        
        # unk -> mask
        querys[querys_unk] = msk_id
        answers[answers_unk] = msk_id
        
        # divide batch (to avoid out of memory)
        querys_div = torch.split(querys, batch_size)
        answers_div = torch.split(answers, batch_size)
        
        # bert masked lm
        new_querys = torch.zeros((1, querys.shape[1]), device=device)
        new_answers = torch.zeros((1, answers.shape[1]), device=device)
        for query_batch, answer_batch in zip(querys_div, answers_div):
            # source mask
            querys_mask = torch.ones(query_batch.shape ,dtype=torch.int32, device=device) * (query_batch != 0)
            answers_mask = torch.ones(answer_batch.shape ,dtype=torch.int32, device=device) * (answer_batch != 0)
            
            # inference
            pre_qry_batch = model(query_batch, attention_mask=querys_mask)
            pre_ans_batch = model(answer_batch, attention_mask=answers_mask)
            
            # collect inference result
            _, qry_batch = torch.topk(pre_qry_batch[0], k=1, dim=2)
            qry_batch = torch.squeeze(qry_batch, dim=2)
            _, ans_batch = torch.topk(pre_ans_batch[0], k=1, dim=2)
            ans_batch = torch.squeeze(ans_batch, dim=2)
            
            # concatnate
            new_querys = torch.cat((new_querys, qry_batch), dim=0)
            new_answers = torch.cat((new_answers, ans_batch), dim=0)
        
        # shape and cast
        new_querys = new_querys[1:].to(torch.int64) # remove zeor tensor
        new_answers = new_answers[1:].to(torch.int64)
        
        if not word_replace:
            querys[querys_unk] = new_querys[querys_unk] # replace only [UNK] position
            answers[answers_unk] = new_answers[answers_unk]
            querys = querys.to(torch.int)
            answers = answers.to(torch.int)
        else:
            querys = new_querys # replace all word
            answers = new_answers
            # recover [CLS] and [SEP] and [PAD]
            querys[querys_cls] = cls_id
            answers[answers_cls] = cls_id
            querys[querys_sep] = sep_id
            answers[answers_sep] = sep_id
            querys[querys_pad] = pad_id
            answers[answers_pad] = pad_id
        
        # write for new file
        path = os.path.join(out_path, file_name)
        with open(path, mode='w') as f:
            # write headder
            if type(headders) == list:
                for headder in headders:
                    f.write(headder + '\n')
            elif type(headders) == str:
                f.write(headders + '\n')
            
            # write data
            for query, answer in zip(querys, answers):
                # remove [PAD]
                qry_pad_mask = (query != pad_id)
                ans_pad_mask = (answer != pad_id)
                query = query[qry_pad_mask]
                answer = answer[ans_pad_mask]
                
                # convert tensor -> str
                query = ' '.join([str(q) for q in query.tolist()])
                answer = ' '.join([str(a) for a in answer.tolist()])
                
                # write record
                f.write('{},{}\n'.format(query, answer))
                
        print('Done.')
if __name__ == '__main__':
    data_recreate(train_path, output_train_path, 1)
    data_recreate(eval_path, output_eval_path, 1)

実行は、dialog_sysディレクトリに移動し、以下のように行います。

$ python utils/replace_unk.py

BERTモデルとデータセットが適切に配置されていれば問題なく実行され、dialog_sys/data ディレクトリにtrain_cleaned, eval_cleaned の2つが生成されているはずなので、train, eval にそれぞれ名前を変えておいてください。

以上で事前準備は終了です。

3. モデル構造

今回作成するモデルの概要は以下のようになっています。
dialog_sys_fig.png
基本的な構造は、Attention付き Seq2Seqモデルの Attentionへの Key-Valueの入力を BERTから取り出された特徴ベクトルが担うといった形になります。
Encoder側の LSTMは文脈情報を Decoder側の LSTMに伝達することが主な役割です。
Decoder側の LSTMは、BERTを用いた Attentionの情報を参照しながら応答を生成・出力します。

なお、損失関数にはCrossEntropyLossとKLDivLossを用いています。
CrossEntropyLoss には教師データと出力の誤差を計算させ、KLDivLoss には BertForMaskedLM と出力分布の誤差を計算させることで、one-hotラベルの教師データで学習するのに比べて、柔軟に学習を行うことが目的です。

4. ソースコード

4.1. Encoder

dialog_sys/nn ディレクトリ下に配置します。

encoder.py
from torch.nn.utils.rnn import pack_sequence, pack_padded_sequence, pad_packed_sequence

import torch
import torch.nn as nn

from transformers import BertModel

def build_encoder(num_layers, bert_model, out_features, lstm_features, bidrectional, enc_len, device='cuda:0'):
    """Generate encoder model.

    Args:
        num_layers (int): LSTM layer number.
        bert_model (BertModel): used Bert model.
        out_features (int): encoder output feature size.
        lstm_features (int): lstm hidden size.
        bidrectional (bool): use bidirectional lstm when bidirectional=True
        enc_len (int): max encoder sequence length.
        device (str): device type. Defaults to 'cuda:0'.

    Returns:
        Encoder(nn.Module): encoder module
    """
    model = Encoder(
        bert_model=bert_model,
        out_features=out_features,
        lstm_hidden_size=lstm_features,
        num_layers=num_layers,
        bidirectional=bidrectional,
        enc_len=enc_len,
        device=device
    )

    return model

class BertEmbedding(nn.Module):
    """
    same nn.Embedding(num_input=32006, emb_size=out_features, padding_idx=0)
    """
    def __init__(self, bert_model=None, out_features=256, device='cuda:0'):
        super(BertEmbedding, self).__init__()

        # Parameters ######################################
        self.out_features = out_features
        self.device = device
        ###################################################

        # Layers ##########################################
        self.bert = bert_model
        self.freez(self.bert) # freeze auto grad
        self.transfer_learning(self.bert)

        self.pooler_dense = nn.Linear(768, self.out_features)
        self.layer_norm = nn.LayerNorm(self.out_features)
        ###################################################

    def forward(self, inputs:torch.tensor):
        """
        Bert Embedding or Bert Encoder Module.

        Attribute
        ---------
        input : 
            type : torch.tensor(int)
            shape : [batch_size, seq_len]
            Input sequence list which is refilled [PAD].
            You need to analyse with proper morphological analyzer for which you use bert model.

        Return
        ------
        output :
            type : torch.tensor(float)
            shape : [batch_size, seq_len, out_features]
            Embedding expression by BERT.
            Padding indices is initialized by 0.0.
        """

        # Generate mask ###################################
        self.source_mask = torch.ones(inputs.shape, dtype=int, device=self.device) * (inputs != 0)
        ###################################################

        # Layers : BERT -> Linear -> LayerNorm ############
        output, _ = self.bert(inputs, self.source_mask)
        output = self.pooler_dense(output)
        output = self.layer_norm(output)
        ###################################################

        # Embedding mask to padding index #################
        self.source_mask = self.source_mask.to(torch.float32)
        output = output * self.source_mask.unsqueeze(-1) 
        ###################################################

        return output

    def freez(self, module):
        for p in module.parameters():
            p.requires_grad = False

    def unfreez(self, module):
        for p in module.parameters():
            p.requires_grad = True
            
    def transfer_learning(self, module:BertModel):
        module.pooler.dense = nn.Linear(768, 768, bias=True)
        module.pooler.dense.requires_grad = True

class Encoder(nn.Module):
    def __init__(self, bert_model=None, out_features=256, lstm_hidden_size=256, num_layers=1, bidirectional=False, enc_len=128, device='cuda:0'):
        """initialize encoder LSTM

        Args:
            out_features (int): output feature size. Defaults to 256.
            lstm_hidden_size (int): lstm hidden vector size. Defaults to 256.
            num_layers (int): lstm layer number. Defaults to 1.
            bidirectional (bool): Defaults to False
            enc_len (int): encoder's sequence length. Defaults to 128.
            device (str): device type. Defaults to 'cuda:0'.
        """
        super(Encoder, self).__init__()

        self.num_layers = num_layers
        self.out_features = out_features
        self.lstm_hidden_size = lstm_hidden_size
        self.bidirectional = bidirectional
        self.enc_len = enc_len
        self.device = device

        self.bert_embed = BertEmbedding(
            bert_model=bert_model,
            out_features=self.out_features,
            device=self.device
        )
        self.embedding = nn.Embedding(32006, 768, padding_idx=0)
        self.input_dense = nn.Linear(768, self.out_features)
        self.input_norm = nn.LayerNorm(self.out_features)
        self.input_drop = nn.Dropout(p=0.2)

        self.lstm = nn.LSTM(
            input_size=self.out_features,
            hidden_size=self.lstm_hidden_size,
            batch_first=True,
            num_layers=self.num_layers,
            bidirectional=self.bidirectional,
        )
        
        directions = 2 if self.bidirectional else 1
        self.directions = directions
        
        d_h = [nn.Linear(self.lstm_hidden_size*directions, self.lstm_hidden_size).to(device) for i in range(num_layers)]
        d_c = [nn.Linear(self.lstm_hidden_size*directions, self.lstm_hidden_size).to(device) for i in range(num_layers)]
        
        self.dense_h = nn.ModuleList(d_h)
        self.dense_c = nn.ModuleList(d_c)
        
        self.dense = nn.Linear(self.out_features, self.out_features)
        
        normh = [nn.LayerNorm(self.lstm_hidden_size).to(device) for i in range(num_layers)]
        normc = [nn.LayerNorm(self.lstm_hidden_size).to(device) for i in range(num_layers)]
        
        self.layer_normh = nn.ModuleList(normh)
        self.layer_normc = nn.ModuleList(normc)

        self.layer_norm = nn.LayerNorm(self.out_features)

    def forward(self, ids:list):
        """
        Attributes
        ----------
        ids : 
            type : list
            shape : [batch_size, seq_len]
            Sentence whose token changed ids.
        """
        
        batch_size = len(ids)

        ids = [seq[-self.enc_len:] for seq in ids]
        packs = pack_sequence([torch.tensor(t, device=self.device) for t in ids], enforce_sorted=False)
        (model_input, lengths_info) = pad_packed_sequence(
            packs, 
            batch_first=True, 
            padding_value=0.0
        )
        

        # BERT Embedding Layer ############################
        outputs = self.bert_embed(model_input)
        ###################################################

        # Generate mask ###################################
        dense_mask = self.bert_embed.source_mask.detach()
        dense_mask = dense_mask.unsqueeze(-1)
        ###################################################

        # Embedding for LSTM ##############################
        lstm_input = self.embedding(model_input)
        lstm_input = self.input_dense(lstm_input)
        lstm_input = self.input_norm(lstm_input)
        lstm_input = self.input_drop(lstm_input)
        ###################################################

        # Don't Need mask ! (hint : length_info)
        lstm_packs = pack_padded_sequence(
            lstm_input,
            lengths=lengths_info, 
            batch_first=True,
            enforce_sorted=False
        )

        # LSTM Layer ####################################
        _, (hn, cn) = self.lstm(lstm_packs)
        
        hn = hn.view([self.directions, self.num_layers, batch_size, self.lstm_hidden_size])
        hn = torch.transpose(hn, 1, 2)
        hn = torch.transpose(hn, 0, 2)
        hn = hn.reshape([self.num_layers, batch_size, self.directions*self.lstm_hidden_size])
        
        cn = cn.view([self.directions, self.num_layers, batch_size, self.lstm_hidden_size])
        cn = torch.transpose(cn, 1, 2)
        cn = torch.transpose(cn, 0, 2)
        cn = cn.reshape([self.num_layers, batch_size, self.directions*self.lstm_hidden_size])
        
        batch_size = len(outputs)
        ###################################################

        # Pooler ##############################

        outputs = self.dense(outputs) # one Attention use.
        outputs = self.layer_norm(outputs)
        outputs = outputs*dense_mask

        hn = [dense(h) for dense, h in zip(self.dense_h, hn)]
        hn = [norm(h) for norm, h in zip(self.layer_normh, hn)]
        hn = torch.cat([h.unsqueeze(0) for h in hn], dim=0)

        cn = [dense(c) for dense, c in zip(self.dense_c, cn)]
        cn = [norm(c) for norm, c in zip(self.layer_normc, cn)]
        cn = torch.cat([c.unsqueeze(0) for c in cn], dim=0)

        return (outputs, hn, cn)

BertEmbeddingクラスに記述してある、freez() でBERTモデルの学習を凍結しています。
また、transfer_learning() は転移学習を実現する関数です。

Encoderの動作を簡単にまとめると、

  1. 形態素解析と ID化済みの発話文の形態素系列をリストとして受け取り、
  2. BERTによる特徴ベクトルの系列と LSTMの最終出力 hn, cn を出力する

となっています。

また、Encoderの LSTMのレイヤー数、双方向・単方向は可変で、後程紹介するコマンドライン引数で変更することができます。
今回学習するEncoderは、2層の双方向LSTMです。

4.2. Attenstion

dialog_sys/nn ディレクトリ下に配置します。

attention.py
import math

import torch
import torch.nn as nn

def build_attention(h_size, enc_size, dec_size, device):
    attention_layer = AttentionLayer(
        h_size=h_size,
        enc_size=enc_size,
        dec_size=dec_size,
        device=device
    )
    return attention_layer

class AttentionLayer(nn.Module):
    def __init__(self, h_size, enc_size, dec_size, device='cuda:0'):
        super(AttentionLayer, self).__init__()

        # Parameters ######################################
        self.h_size = h_size
        self.enc_size = enc_size
        self.dec_size = dec_size
        self.scale_factor = math.sqrt(self.h_size)
        self.device = device
        ###################################################

        # Layers ##########################################
        self.query_dense = nn.Linear(self.h_size, self.h_size)
        self.key_dense = nn.Linear(self.dec_size, self.h_size)
        self.value_dense = nn.Linear(self.dec_size, self.h_size)

        self.softmax = nn.Softmax(dim=2)

        self.output_dense = nn.Linear(self.h_size, self.h_size)
        self.shape_dence = nn.Linear(self.h_size + self.dec_size, self.h_size)

        self.layer_norm = nn.LayerNorm(self.h_size)
        ###################################################

    def forward(self, encoder_out:torch.tensor, decoder_out:torch.tensor):
        """
        Attributes
        ----------
        encoder_out : 
            type : tensor(torch.float32)
            shape : [batch_size, seq_len, h_size]
            This is encoder outputs.
            Its roll in Attention is Key and Value.
        decoder_out :
            type : tensor(torch.float32)
            shape : [N, batch_size, h_size]
            This is decoder outputs.
            Its roll in Attention is Query.

        Returns
        -------
        final_output : 
            type : tensor(torch.float32)
            shape : [N, batch_size, h_size]
            Context vector + decoder_output.
        """

        key = encoder_out.clone().to(device=self.device)
        value = encoder_out.clone().to(device=self.device)
        query = decoder_out.clone().to(device=self.device)

        # Dense Layers ####################################
        key = self.key_dense(key)
        value = self.value_dense(value)
        query = self.query_dense(query)

        query = query/self.scale_factor
        ###################################################

        # Mat Mul Layer ###################################
        logit = torch.bmm(query, key.transpose(1,2))

        attention_weight = self.softmax(logit)

        output = torch.bmm(attention_weight, value)
        context_vec = self.output_dense(output)
        ###################################################

        # Add Layer #######################################
        final_out = torch.cat((decoder_out, context_vec), axis=-1)
        final_out = self.shape_dence(final_out)
        final_out = self.layer_norm(final_out)
        ###################################################

        return final_out

Attentionレイヤーは以下の記事を参考に組み立てました。
作って理解する Transformer / Attention

Attentionは次の式で表現されます。

Attention(Q, K, V) = softmax(\frac{QK^T}{\sqrt{d_k}})V

$Q$ がDecoderの出力、$K$ と $V$ がEncoder(今回はBERT)のベクトル系列(=行列)です。
$d_k$ は $Q$ と $K$ の次元数で、softmax の logit の値が大きくなりすぎて飽和してしまうことを防ぐ役割があります。
attention.py 中では self.scale_factor という名前の変数が $ \sqrt{d_k} $ にあたります。

Attention が行っている計算は単純で、$Q$ と $K$ の類似度(のようなもの)を $QK^T$ で算出し、その類似度を重みとして $V$ を加重平均することです。
こうして取り出された $V$ の加重平均はコンテキストベクトル(文脈ベクトル)として利用されます。
今回のモデルに当てはめれば、EncoderのBERTから出力されたベクトル系列の各ベクトルを、DecoderのLSTMから出力されたベクトルの類似度で加重平均している...ということになります。
こうすることで入力系列全体から適切に情報を抽出し、応答生成に活かすことができるのです。

ここで類似度について補足しておきます。
ベクトルの内積の定義は

\mathbf{a}\mathbf{b} = |\mathbf{a}||\mathbf{b}|cos\theta

このように表現でき、つまりここでいう類似度は、「ベクトルとベクトルの向きの類似性」と言い換えることができそうです。
自然言語処理の分散表現(word2vecなど)においてベクトルの加減算で別の単語ベクトルが表現できたことを考えると納得の仕組みです。

4.3. Decoder

dialog_sys/nn ディレクトリ下に配置します。

decoder.py
import torch
import torch.nn as nn

from torch.nn.utils.rnn import pad_packed_sequence, pack_sequence

from .attention import build_attention

def build_decoder(num_layers, out_features, lstm_features, bert_hsize, bert_tokenizer, max_len, device='cuda:0'):
    decoder = Decoder(
        num_layers,
        out_features,
        lstm_features,
        bert_hsize,
        bert_tokenizer,
        max_len,
        device
    )

    return decoder

VOCAB_SIZE = 32006

class Decoder(nn.Module):
    def __init__(self, num_layers, out_features, lstm_features, bert_hsize, bert_tokenizer, max_len, device='cuda:0'):
        super(Decoder, self).__init__()

        self.bert_tokenizer = bert_tokenizer

        # Parameters ######################################
        self.num_layers = num_layers
        self.out_features = out_features
        self.lstm_features = lstm_features
        self.bert_hsize = bert_hsize
        self.device = device
        self.max_len = max_len
        ###################################################

        # Layers ##########################################
        self.embedding = nn.Embedding(VOCAB_SIZE, self.out_features)
        self.dense1 = nn.Linear(self.out_features, self.out_features)
        self.layer_norm0 = nn.LayerNorm(self.out_features)
        self.dropout = nn.Dropout(p=0.2)
        
        self.lstm = nn.LSTM(
            self.out_features, 
            self.lstm_features, 
            num_layers=self.num_layers, 
            batch_first=True, 
        )
        self.dense_lstm = nn.Linear(self.lstm_features, self.lstm_features)
        self.dence_out = nn.Linear(self.out_features, self.out_features)
        self.pooler = nn.Linear(self.out_features, VOCAB_SIZE)
        self.dence_feature = nn.Linear(self.out_features, self.bert_hsize)
        self.softmax = nn.Softmax(dim=2) # use only eval mode.
        
        self.attention = build_attention(
            h_size=self.out_features,
            enc_size = self.lstm_features,
            dec_size = self.lstm_features,
            device=self.device
        )
        ###################################################

    def forward(self, key_value, h:torch.tensor, c:torch.tensor, response, mode='train'):
        """

        Args:
            key_value (torch.tensor): Encoder output vector sequence
            h (torch.tensor): Encoder LSTM output vector
            c (torch.tensor): Encoder LSTM Cell vector
            response (list): response word ID sequence
            mode (str, optional): 'train' or 'eval'. Defaults to 'train'.

        Returns:
            (outputs, (hn, cn), out_fvs)
            outputs (torch.tensor): Model output for CrossEntropyLoss
            hn (torch.tensor): Encoder LSTM output vector
            cn (torch.tensor): Encoder LSTM Cell vector
            out_fvs (torch.tensor): Model output for KLDivLoss
        """

        packs = pack_sequence([torch.tensor(t, device=self.device) for t in response], enforce_sorted=False)
        (model_input, lengths_info) = pad_packed_sequence(
            packs, 
            batch_first=True, 
            padding_value=0.0
        )
        
        if response == None:
            self.batch_size = 1
            self.seq_len = self.max_len
        else:
            self.batch_size = len(model_input)
            self.seq_len = len(model_input[0])
        h = h.clone()
        c = c.clone()

        if mode == 'train':
            result = self.train_forward(key_value, h, c, model_input)
        elif mode == 'eval':
            result = self.eval_forward(key_value, h, c)
        else:
            raise Exception('Invalid Mode :',mode)

        return result

    def train_forward(self, key_value, h, c, response:torch.tensor):
        inputs = response

        # Embedding & Dropout & Dense ##################### 
        embed = self.embedding(inputs.clone())
        affin = self.dense1(embed)
        affin = self.layer_norm0(affin)
        drop = self.dropout(affin)
        ###################################################

        # LSTM with LayerNorm #############################
        outs, (hn, cn) = self.lstm(drop, (h, c))

        # Attention #########################
        cntxt = self.attention(key_value, outs)
        ###################################################

        # Pooling Layer ###################################
        cntxt = self.dence_out(cntxt)
        out_fvs = self.dence_feature(cntxt)
        pooler_outputs = self.pooler(cntxt)
        ###################################################

        return pooler_outputs, (hn, cn), out_fvs

    def eval_forward(self, key_value, h, c):
        hn = h
        cn = c

        count = 0
        outputs=torch.zeros([self.batch_size, 1, VOCAB_SIZE]).to(self.device)
        out_fvs=torch.zeros([self.batch_size, 1, self.bert_hsize]).to(self.device)

        word = '[CLS]'
        inputs = [self.bert_tokenizer.convert_tokens_to_ids(word)] * self.batch_size

        while count < self.max_len:
            # Tokenizer ###################################
            input_tensor = torch.tensor(inputs, device=self.device)
            input_tensor = input_tensor.view([self.batch_size, 1])
            ###############################################

            # Embedding & Dropout & Dense #################
            embed = self.embedding(input_tensor.clone())
            affin = self.dense1(embed)
            affin = self.layer_norm0(affin)
            drop = self.dropout(affin)
            ###############################################

            # LSTM with LayerNorm #########################
            out_lstm, (hn, cn) = self.lstm(drop, (hn, cn))

            # Attention #########################
            cntxt = self.attention(key_value, out_lstm)
            ###############################################

            # Pooling Layer ###############################
            cntxt = self.dence_out(cntxt)
            dense_fvs = self.dence_feature(cntxt)
            pooler_outs = self.pooler(cntxt)
            ###############################################

            # to token ####################################
            batch_words = pooler_outs.view([self.batch_size, VOCAB_SIZE])
            (_, batch_ids) = batch_words.max(1)
            ###############################################

            inputs = batch_ids.tolist()
            outputs = torch.cat((outputs, pooler_outs), dim=1)
            out_fvs = torch.cat((out_fvs, dense_fvs), dim=1)
            count = count + 1

        return outputs[:,1:,:], (hn, cn), out_fvs[:,1:,:]


Decoderの LSTMは単方向であり、レイヤー数は Encoderの LSTMと同じになっています。
Decoder は学習時と推論時で動作が異なるため、学習用の train_forward() と推論用の eval_forward() に分けて記述しています。

学習時は教師データの系列の先頭に '[CLS]' トークンを挿入した系列を Decoderに入力し、応答を生成しますが、推論時は教師データは与えられないため、1ステップずつ応答の単語を出力し、出力された単語を次のステップの入力にするという方法になります。

簡単にまとめたものが以下の図になります。
LSTM_learn.png
応答生成では、単語 $ w_{n-1} $ と LSTMの一つ前の時刻の出力ベクトル $h_{n-1}$, $ c_{n-1}$ から、単語 $ w_n $ を生成するというタスクを行います。
この、ベクトル $h_{n-1}$ と $c_{n-1}$ は文脈情報と呼ばれることもあるため、つまり応答生成は、文脈情報と一つ前のステップの単語から次に来る単語を予測するタスクと言えるでしょう。

4.4. Model synthesize

dialog_sys/nn ディレクトリ下に配置します。

encoder_decoder.py
import torch.nn as nn

from .encoder import build_encoder
from .decoder import build_decoder

def build_model(args, bert_model, bert_tokenizer, device):
    model = EncoderDecoder(args, bert_model, bert_tokenizer, device)
    return model

class EncoderDecoder(nn.Module):
    """
    Synthesize models
    """

    def __init__(self, args, bert_model, bert_tokenizer, device):
        super(EncoderDecoder, self).__init__()

        self.encoder = build_encoder(
            num_layers=args.num_layer,
            bert_model=bert_model,
            out_features=args.hidden_size,
            lstm_features=args.lstm_hidden_size,
            bidrectional=args.use_bidirectional,
            enc_len=args.enc_len,
            device=device
        )
        self.decoder = build_decoder(
            num_layers=args.num_layer,
            out_features=args.hidden_size,
            lstm_features=args.lstm_hidden_size,
            bert_tokenizer=bert_tokenizer,
            max_len=args.max_len,
            bert_hsize=args.bert_hsize,
            device=device
        )

    def forward(self, inputs, mode):
        """
        Parameters
        ----------
        inputs : 
            type : list
            shape : [2, batch_size, seq_len]
            [0,:,:] is query.
            [1,:,:] is response.
        mode : 
            type : string
            'train' or 'eval'
        """
        # Encoder #########################################
        hs, h, c = self.encoder(inputs[0])
        ###################################################

        # Decoder #########################################
        outputs, (_, _), out_fvs = self.decoder(
            hs, 
            h,
            c,
            response=inputs[1],
            mode=mode
        )
        ###################################################

        return outputs, out_fvs

4.5. Dataloader

dialog_sys/utils ディレクトリ下に配置します。

dataloader.py
import os
import re
import math
import linecache

import torch.utils.data as data
from pyknp import Juman
from sklearn.utils import shuffle

from argparse import ArgumentParser


# Juman++
jumanpp = Juman()

def tokenize(text):
    """
    You must input data at a one.
    """
    text = re.sub('\n', '', text)
    result = jumanpp.analysis(text)
    tokenized_text = [mrph.midasi for mrph in result.mrph_list()]

    return tokenized_text


class idDataLoader():
    def __init__(self, path, args:ArgumentParser):
        self.path = path
        self.all_length = 0
        self.each_length = []
        self.files = os.listdir(self.path)
        self.file_num = len(self.files)

        self.args = args

        self.__len__()

    def __call__(self, file_index, line_index):
        if file_index >= self.file_num:
            raise ValueError('index out of bounds.', file_index, self.file_num)

        result = ['','']
        file = self.files[file_index]
        file_len = self.each_length[file_index]

        if line_index > file_len:
            raise ValueError('index out of bounds.', line_index, file_len)

        f_path = os.path.join(self.path, file)
        line = linecache.getline(f_path, (line_index+1))
        linecache.clearcache()

        line = re.sub('\n', '', line)
        query_response = line.split(',')
        result[0] = query_response[0]
        result[1] = query_response[1]

        result[0] = result[0].split(' ')
        result[1] = result[1].split(' ')

        result[0] = [int(token_id) for token_id in result[0]]
        result[1] = [int(token_id) for token_id in result[1]]

        return result

    def __len__(self):
        if self.all_length != 0:
            return int(self.all_length/self.args.batch_size + 0.5)

        for file in self.files:
            f_path = os.path.join(self.path, file)
            f = open(f_path, 'r', encoding='utf-8')
            st = f.readline()
            f.close()
            tmp = int(st)
            if tmp < 0:
                tmp = 0
            self.each_length.append(tmp)
            self.all_length = self.all_length + tmp
            

        return math.ceil(self.all_length/self.args.batch_size)

class textDataset(data.Dataset):
    def __init__(self, dataloader, batch_size, make_mode=False, path=None):
        self.dataloader = dataloader
        self.file_num = self.dataloader.file_num
        self.each_length = self.dataloader.each_length
        self.batch_size = batch_size
        self.make_mode = make_mode

        self.path = path

        self.table = []
        self.generate_table()
        if not self.make_mode:
            self.next_epoch()
        self.current_dt = 0

    def __getitem__(self, idx):
        file_index = self.table[idx][0]
        line_index = self.table[idx][1]
        return self.dataloader(file_index, line_index)

    def __len__(self):
        length = math.ceil(len(self.table)/self.batch_size)
        return length

    def __iter__(self):
        return self

    def __next__(self):
        if self.current_dt == self.__len__():
            self.next_epoch()
            raise StopIteration()
        dts = [[],[]]
        b_size = 0
        if self.current_dt + 1 == self.__len__():
            b_size = len(self.table) % self.batch_size
        else:
            b_size = self.batch_size

        tmp = []
        for b in range(b_size):
            file_index = self.table[self.current_dt * self.batch_size + b][0]
            line_index = self.table[self.current_dt * self.batch_size + b][1]
            tmp = self.dataloader(file_index, line_index)
            dts[0].append(tmp[0])
            dts[1].append(tmp[1])

        if self.current_dt + 1 == self.__len__():
            dif = self.batch_size - b_size
            for _ in range(dif):
                dts[0].append([0])
                dts[1].append([0])

        self.current_dt += 1

        return dts

    def next_epoch(self):
        self.table = shuffle(self.table)
        self.current_dt = 0

    def generate_table(self):
        for i in range(self.file_num):
            for n in range(self.each_length[i]+1):
                if n==0:
                    continue
                tap = (i,n)
                self.table.append(tap)

先ほど環境構築・事前準備のくだりで述べたデータ取り出しの操作を行います。
textDatasetクラスに定義した __iter__()、 __len__()、 __next__() などの関数は tqdm に対応するためのものです。
これらを定義することで tqdm によるプログレスバーを表示し、学習の進捗を監視することができるようになります。

4.6. Scorer

dialog_sys/ngram/scoring ディレクトリ下に配置してください。

score_sentence.py
# coding: utf-8

from argparse import ArgumentParser
import math
import kenlm
import re

class ScoreSentence(object):
    def __init__(self, args:ArgumentParser):
        self.args = args
        self.lmean = args.length_dist_mean
        self.lvar = args.length_dist_var
        self.scale = args.length_dist_scale
        
        nd_exp = lambda x : self.scale * math.exp(-((x-self.lmean)**2)/(2*self.lvar**2))
        self.len_score = lambda l : nd_exp(l)/(math.sqrt(2*math.pi)*self.lvar)
        
        model_path = self.args.used_ngram_model
        self.ngram = kenlm.LanguageModel(model_path)
        
    def __call__(self, parsed_sentences):
        # preprocess
        parsed_sentences = re.sub('[.。.]', '', parsed_sentences)
        parsed_sentences = re.sub('[,、,]', '', parsed_sentences)
        parsed_sentences = re.sub('。$', '', parsed_sentences)
        parsed_sentences = re.sub('、$', '', parsed_sentences)
        parsed_sentences = re.sub('[,.、。,.][,.、。,.]+', '', parsed_sentences)
        
        # dvision sub sentence
        parsed_sentences = re.split('', parsed_sentences)
        parsed_sentences = [re.split('', ps) for ps in parsed_sentences]
        
        scores = []
        sentence = []
        prob_sum = 0.0
        len_all = 0.0
        cut_flag = False
        
        # Show scores and n-gram matches
        for sentence_prd in parsed_sentences:
            for sentence_cnm in sentence_prd:
                words = ['<s>'] + sentence_cnm.split() + ['</s>']
                ngram_scores = self.ngram.full_scores(sentence_cnm)
                scores.append(ngram_scores)
                
                for i, (prob, length, oov) in enumerate(ngram_scores):
                    if prob > self.args.remove_limit and not cut_flag:
                        if words[i] != '<s>' and words[i] != '</s>':
                            sentence.append(words[i])
                            #prob_sum += prob
                            prob_sum += 1/math.pow(2,prob)
                            len_all += 1.
                    else:
                        cut_flag = True
                sentence.append('')
            sentence.pop(-1)
            sentence.append('')
        
        prob_mean = 0.0
        if len_all == 0:
            prob_mean = -1000.0
        else:
            #prob_mean = prob_sum / len_all
            prob_mean = math.log2(len_all / prob_sum)
            prob_mean += self.len_score(len_all) # add length bonus
        
        return prob_mean, sentence

score_sentence.py が応答文のスコアリングを行うコードです。
環境構築・事前準備で準備したKenLMを用います。

ソースコードの、下から5~3行目にあるprob_mean にスコアリング結果が計算されます。
最後から4行目を見ると、「あれ、最後から5行目のコードが正しいのでは?」と思われるかもしれませんが調和平均を計算していますので、4行目が正解です。
調和平均は相加平均や相乗平均より最悪ケースの影響が大きい平均方法なので、一か所の文の不自然さがスコアに与える影響を大きくすることができます。
また、KenLMによるN-gram言語モデルは、出力が対数確率になっているため、それに準じた処理を行っています。

上記のスコアリングに加えて、最後から3行目の部分で、生成された文章の長さに対する加点を行っています。文の長さに対する加点は正規分布によって計算しています。詳しくは self.len_score() のラムダ式の定義部分を参照してください。

4.7. Logger Generator

dialog_sys ディレクトリ下に配置します。

logger_gen.py
from logging import getLogger, StreamHandler, FileHandler, Formatter, INFO
from datetime import datetime


def set_logger(name, rootname="log/main.log"):
    dt_now = datetime.now()
    dt = dt_now.strftime('%Y%m%d_%H%M%S')
    fname = rootname + "." + dt
    logger = getLogger(name)
    handler1 = StreamHandler()
    handler1.setFormatter(Formatter("%(asctime)s | %(levelname)s | %(name)s | %(message)s"))
    handler2 = FileHandler(filename=fname)
    handler2.setFormatter(Formatter("%(asctime)s | %(levelname)s | %(name)s | %(message)s"))
    handler1.setLevel(INFO)
    handler2.setLevel(INFO)  #handler2 more than Level.WARN
    logger.addHandler(handler1)
    logger.addHandler(handler2)
    logger.setLevel(INFO)
    return logger

4.8. Trainer (main.py)

dialog_sys ディレクトリ下に配置します。

main.py
from tqdm import tqdm
from argparse import ArgumentParser
from logger_gen import set_logger

from nn.encoder_decoder import build_model
from utils.dataloader import textDataset, idDataLoader

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_packed_sequence, pack_sequence
from transformers import BertModel, BertTokenizer, BertConfig, BertForMaskedLM


def add_args(parser: ArgumentParser):
    parser.add_argument('--model-path', default=None, type=str, help='training target model (.pth)')
    parser.add_argument('--epoch-num', default=1, type=int, help='number of epoch')
    parser.add_argument('--batch-size', default=32, type=int, help='batch size of train')
    parser.add_argument('--enc-len', default=128, type=int, help='intput sequence length of encoder')
    parser.add_argument('--max-len', default=62, type=int, help='generated sequence length')
    parser.add_argument('--lr', default=1e-3, type=float, help='learning rate')
    parser.add_argument('--beta1', default=0.9, type=float, help='one of beta value of AdamW')
    parser.add_argument('--beta2', default=0.98, type=float, help='one of beta value of AdamW')
    parser.add_argument('--weight-decay', default=1e-2, type=float, help='weight decay coefficient of AdamW')
    parser.add_argument('--criterion-reduction', default='mean', type=str, help='reduction which is parameter of criterion')
    parser.add_argument('--adamw-eps', default='1e-8', type=float, help='eps which is parameter of optimizer AdamW')
    parser.add_argument('--disp-progress', default=False, action='store_true', help='display training progress')
    parser.add_argument('--train-data', default='data/train', type=str, help='path of training data')
    parser.add_argument('--eval-data', default='data/eval', type=str, help='path of evaluation data')
    parser.add_argument('--bert-path', default='resource/bert/Japanese_L-12_H-768_A-12_E-30_BPE_WWM/pytorch_model.bin', type=str, help='path of BERT model')
    parser.add_argument('--bert-config', default='resource/bert/Japanese_L-12_H-768_A-12_E-30_BPE_WWM/bert_config.json', type=str, help='path of BERT config file')
    parser.add_argument('--bert-vocab', default='resource/bert/Japanese_L-12_H-768_A-12_E-30_BPE_WWM/vocab.txt', type=str, help='path of BERT torkenizer vocab')
    
    parser.add_argument('--cross-loss-coef', default=0.5, type=float, help='ocoefficient of CrossEntropyLoss')
    parser.add_argument('--mse-loss-coef', default=0.5, type=float, help='coefficient of MSELoss')
    
    parser.add_argument('--num-layer', default=2, type=int, help='LSTM layer number')
    parser.add_argument('--use-bidirectional', default=False, action='store_true', help='use BiDirectioal LSTM')
    parser.add_argument('--bert-hsize', default=768, type=int, help='hidden vector size of BERT')
    parser.add_argument('--hidden-size', default=256, type=int, help='hidden layer size')
    parser.add_argument('--lstm-hidden-size', default=256, type=int, help='hidden layer size')
    
    return parser

class TrainModel(object):
    def __init__(self, args, logger, device='cpu'):
        self.args = args
        self.logger = logger
        self.device = device
        
        # load BERT model
        self.bert_config = BertConfig.from_json_file(self.args.bert_config)
        self.bert_model = BertModel.from_pretrained(self.args.bert_path, config=self.bert_config)
        self.bert_mask = BertForMaskedLM.from_pretrained(self.args.bert_path, config=self.bert_config)
        self.bert_tokenizer = BertTokenizer.from_pretrained(self.args.bert_vocab, do_lower_case=False, do_basic_tokenize=False)
        self.vocab_size = self.bert_tokenizer.vocab_size
        
        self.bert_mask = self.bert_mask.to(device=self.device)
        
        # make dataloader
        dataloader0 = idDataLoader(self.args.train_data, self.args)
        dataset0 = textDataset(dataloader0, batch_size=self.args.batch_size)
        dataloader1 = idDataLoader(self.args.eval_data, self.args)
        dataset1 = textDataset(dataloader1, batch_size=self.args.batch_size)
        
        dataloader_train = dataset0
        dataloader_test = dataset1

        self.dataloader = {'train':dataloader_train, 'eval':dataloader_test}
        
        # building model
        self.model = build_model(
            self.args,
            self.bert_model,
            self.bert_tokenizer,
            self.device
        )
        
        # setting criterion & optimizer
        self.criterion1 = nn.CrossEntropyLoss(
            ignore_index=0
        )
        #self.criterion2 = nn.MSELoss(
        self.criterion2 = nn.KLDivLoss(
            reduction=self.args.criterion_reduction
        )
        self.optimizer = optim.AdamW(
            self.model.parameters(),
            lr=self.args.lr,
            betas=(self.args.beta1, self.args.beta2),
            eps=self.args.adamw_eps,
            weight_decay=self.args.weight_decay
        )
        
        # load model or make new model param file
        try:
            net_dic = torch.load(args.model_path, map_location=device)
            logger.info('loading model ... ')
            self.model.load_state_dict(net_dic)
            logger.info('Done.')
        except Exception:
            logger.info('{} does not exist.'.format(args.model_path))
            logger.info('save model to {}'.format(args.model_path))
            torch.save(self.model.state_dict(), args.model_path)
            
        self.model.to(device=self.device)
        
        # acceleration when forward propagation and loss function calculation methods are constant.
        torch.backends.cudnn.benchmark = True
        # accelerate GPU & save memory
        self.scaler = torch.cuda.amp.GradScaler()

class Trainer(object):
    def __init__(self, args, logger):
        self.args = args
        self.logger = logger
        
        self.epoch_num = self.args.epoch_num
        self.max_len = self.args.max_len
        self.batch_size = self.args.batch_size

    def __call__(self, tm: TrainModel):

        train_batch_loss = []
        eval_batch_loss = []
        train_batch_acc = []
        eval_batch_acc = []

        for epoch in range(self.epoch_num):
            self.logger.info('Epoch {}/{}'.format(epoch+1, self.epoch_num))

            for phase in ["train", "eval"]:
                if phase == "train":
                    tm.model.train()
                else:
                    tm.model.eval()

                score = self.calc_loss(tm, phase)

                if phase == "train":
                    train_batch_loss.append(score['loss'])
                    train_batch_acc.append(score['acc'])
                else:
                    eval_batch_loss.append(score['loss'])
                    eval_batch_acc.append(score['acc'])
                
        return (train_batch_loss, eval_batch_loss, train_batch_acc, eval_batch_acc)
                
    def calc_loss(self, tm:TrainModel, phase):
        epoch_loss = 0.0
        epoch_corrects = 0
        epoch_data_sum = 0

        turn = 0
        for query, response in tqdm(tm.dataloader[phase]):

            # make teacher data
            packs = pack_sequence([torch.tensor(t, device=device).clone().detach() for t in response], enforce_sorted=False)
            (answer, _) = pad_packed_sequence(
                packs, 
                batch_first=True, 
                padding_value=0.0
            )

            if len(answer[0]) < self.max_len:
                zero_pad = torch.zeros([self.batch_size, (self.max_len - len(answer[0]))], dtype=int, device=device)
                answer = torch.cat((answer, zero_pad), dim=1).contiguous()
            if len(answer[0]) > self.max_len:
                answer = answer[:,:self.max_len].contiguous()

            answer_resp = answer[:,1:].clone().detach() # remove [CLS]
            answer_resp = torch.cat((answer_resp, torch.zeros([self.batch_size, 1], dtype=int, device=device)), dim=1)
            answer_resp = answer_resp.contiguous()

            input_decoder = answer
            input_decoder = input_decoder.tolist()

            answer_mask = torch.ones(answer_resp.shape, dtype=int, device=tm.device) * (answer_resp != 0)
            # learning [PAD]
            answer_mask = answer_mask

            # initialize
            tm.optimizer.zero_grad()

            # Forward propagation
            with torch.set_grad_enabled(phase=="train"), torch.cuda.amp.autocast():
                outputs, feature_vecs = tm.model([query, input_decoder], phase)
                outputs.contiguous()
                
                # for loss1
                vocab_size = tm.vocab_size
                outputs_cri = outputs.contiguous().view([-1, vocab_size]) # [batch_size, seq_len, vocab_size] >> [(batch_size * seq_len), vocab_size]
                outputs_cri = outputs_cri.contiguous()
                answer_resp_cri = answer_resp.view(-1) # [batch_size, seq_len] >> [(batch_size * seq_len)]
                
                # for loss2
                outputs_lsm = F.log_softmax(outputs, dim=-1)
                answer_mask = torch.ones(answer_resp.shape ,dtype=torch.int32, device=device) * (answer_resp != 0)
                pre_answer = tm.bert_mask(answer_resp, attention_mask=answer_mask)
                pre_answer = F.log_softmax(pre_answer[0], dim=-1)
                
                # Loss for cross entropy
                loss1 = tm.criterion1(outputs_cri, answer_resp_cri)
                loss2 = tm.criterion2(outputs_lsm, pre_answer)
                loss = args.cross_loss_coef * loss1 + args.mse_loss_coef * loss2

                # Prediction
                _, preds = torch.max(outputs, 2)
                preds = preds * answer_mask

                # Learning
                if phase=="train":
                    tm.scaler.scale(loss).backward()
                    tm.scaler.step(tm.optimizer)
                    tm.scaler.update()

                # Loss is mean of batch. So total batch loss is bellow.
                epoch_loss += loss.item() * self.batch_size
                # correct
                epoch_corrects += torch.sum(preds == answer_resp) - torch.sum(answer_mask==0)
                epoch_data_sum += torch.sum(answer_mask)
            
            turn += 1

        torch.save(tm.model.state_dict(), self.args.model_path)
        logger.info('complete saving model')

        # display loss & acc for each epoch
        epoch_loss = epoch_loss / len(tm.dataloader[phase])
        epoch_acc = epoch_corrects.double() / epoch_data_sum

        logger.info("{} Loss {:.4f} Acc: {:.4f}".format(phase, epoch_loss, epoch_acc))
            
        score = {}
        score['loss'] = epoch_loss
        score['acc'] = epoch_acc
        
        return score


# setup argument parser & logger
parser = ArgumentParser('This program is trainer for seq2seq with attention model.')
parser = add_args(parser)
args = parser.parse_args()
logger = set_logger("training", "log/train.log")

# GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
logger.info("Device: {}".format(device))

trainset = TrainModel(args, logger, device)
trainer = Trainer(args, logger)

# training model
loss = trainer(trainset)

main.pyを実行する際、必要なモデルやパラメータは次のように引数で指定する仕様になっています。

python main.py \
    --model-path params/model.pth \
    --epoch-num 20 \
    --batch-size 64 \
    --enc-len 128 \
    --max-len 64 \
    --lr 1e-4 \
    --beta1 0.9 \
    --beta2 0.98 \
    --criterion-reduction batchmean \
    --adamw-eps 1e-9 \
    --disp-progress \
    --train-data data/train \
    --eval-data data/eval \
    --bert-path resource/bert/Japanese_L-12_H-768_A-12_E-30_BPE_WWM/pytorch_model.bin \
    --bert-config resource/bert/Japanese_L-12_H-768_A-12_E-30_BPE_WWM/bert_config.json \
    --bert-vocab resource/bert/Japanese_L-12_H-768_A-12_E-30_BPE_WWM/vocab.txt \
    --num-layer 2 \
    --use-bidirectional \
    --cross-loss-coef 1 \
    --mse-loss-coef 1 \
    --bert-hsize 768 \
    --hidden-size 128 \
    --lstm-hidden-size 128

引数の入力は非常に面倒なので、vscode のデバッグ機能を使うと簡単です。
ここでは vscode のデバッグ機能についての説明は、話の本筋から外れるため割愛しますが、便利なのでぜひ導入を検討してみてください。

デバッグで使用する launch.jsonは以下の通りです。

launch.json
{
    // Use IntelliSense to learn about possible attributes.
    // Hover to view descriptions of existing attributes.
    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
    "version": "0.2.0",
    "configurations": [
        {
            "name": "Python: train",
            "type": "python",
            "request": "launch",
            "program": "main.py",
            "console": "integratedTerminal",
            "justMyCode": true,
            "cwd": "/path/to/dialog_sys",
            "args": [
                "--model-path",
                "params/model.pth",
                "--epoch-num",
                "20",
                "--batch-size",
                "64",
                "--enc-len",
                "128",
                "--max-len",
                "64",
                "--lr",
                "1e-4",
                "--beta1",
                "0.9",
                "--beta2",
                "0.98",
                "--criterion-reduction",
                "batchmean",
                "--adamw-eps",
                "1e-9",
                "--disp-progress",
                "--train-data",
                "data/train",
                "--eval-data",
                "data/eval",
                "--bert-path",
                "resource/bert/Japanese_L-12_H-768_A-12_E-30_BPE_WWM/pytorch_model.bin",
                "--bert-config",
                "resource/bert/Japanese_L-12_H-768_A-12_E-30_BPE_WWM/bert_config.json",
                "--bert-vocab",
                "resource/bert/Japanese_L-12_H-768_A-12_E-30_BPE_WWM/vocab.txt",
                "--num-layer",
                "2",
                "--use-bidirectional",
                "--cross-loss-coef",
                "1",
                "--mse-loss-coef",
                "1",
                "--bert-hsize",
                "768",
                "--hidden-size",
                "128",
                "--lstm-hidden-size",
                "128"
            ]
        },
        {
            "name": "Python: dialog",
            "type": "python",
            "request": "launch",
            "program": "dialog.py",
            "console": "integratedTerminal",
            "justMyCode": true,
            "cwd": "/path/to/dialog_sys",
            "args": [
                "--model-path",
                "params/model.pth",
                "--enc-len",
                "128",
                "--max-len",
                "64",
                "--bert-path",
                "resource/bert/Japanese_L-12_H-768_A-12_E-30_BPE_WWM/pytorch_model.bin",
                "--bert-config",
                "resource/bert/Japanese_L-12_H-768_A-12_E-30_BPE_WWM/bert_config.json",
                "--bert-vocab",
                "resource/bert/Japanese_L-12_H-768_A-12_E-30_BPE_WWM/vocab.txt",
                "--num-layer",
                "2",
                "--use-bidirectional",
                "--bert-hsize",
                "768",
                "--hidden-size",
                "128",
                "--lstm-hidden-size",
                "128",
                "--used-ngram-model",
                "ngram/scoring/models/ngram.binary",
                "--remove-limit",
                "-5.0",
                "--resp-gen",
                "50",
                "--length-dist-mean",
                "15",
                "--length-dist-var",
                "7",
                "--length-dist-scale",
                "1",
            ]
        }
    ]
}

launch.json の cwd の部分は dialog_sys ディレクトリまでの絶対パスを入力してください。

各引数の意味などを確認したい場合は、次のコマンドを実行して確認してください。

python main.py --help

5. 学習結果

以下のハイパーパラメータで学習を行いました。(launch.json と同じ)

  • エポック数:20
  • 隠れ層のサイズ:128
  • LSTMの隠れベクトルサイズ:128
  • Encoder:双方向LSTM
  • LSTM層:2層
  • 学習率:1e-4
  • AdamW eps:1e-9
  • バッチサイズ:64
  • CrossEntropyLoss係数:1
  • KLDivLoss係数:1

FNF_loss.png
FNF_acc.png
FNF_eval_loss.png
FNF_eval_acc.png

検証データ数が少なかったため、かなりグラフが不安定ですが、学習が進むほど検証データの損失値が上昇していることがわかります。
しかし、検証データに対する精度は若干向上していることが確認できます。
学習データは32万ペア程度でそこまで多くはなく、損失計算は単純に実装しています。そのため、学習が進むほど、学習データに最適化され、検証データに対する損失値も精度も悪化していくと予想していましたが、検証データに対する精度が向上していたのは驚きです。
ただ、やはり損失値に関しては、学習データと検証データで負の相関があるため、他の指標の導入を検討する必要があるでしょう。

6. 実際の対話

気を取り直して、実際に学習したモデルと対話したいと思います。
対話に用いるコードは以下の通りです。
どちらも、dialog_sys ディレクトリ下に配置してください。

BJtokenizer.py
from pyknp import Juman
from transformers import BertTokenizer
import re


# parameters
bert_seq_len = 512

def remove_duplication(tokens):
    """Remove duplication such as 'XYZABABABAB -> XYZAB'

    Args:
        tokens (list[str]): sentence which is tokenized

    Returns:
        list[str]: sentence which is removed duplication
    """
    if len(tokens) < 2:
        return tokens
    
    # remove duplication : 'XYZABABABAB -> XYZAB'
    tokens += ['P', 'P', 'P'] # add dummy element
    
    new_tokens = []
    for i in range(len(tokens)-3):
        source = ' '.join([tokens[i], tokens[i+1]])
        target = ' '.join([tokens[i+2], tokens[i+3]])
        new_tokens.append(tokens[i])
        
        if source == target:
            new_tokens.append(tokens[i+1])
            break
        else:
            continue
        
    tokens = new_tokens
    
    # remove duplication : 'XYZABCABCABC -> XYZABC'
    tokens += ['P', 'P', 'P', 'P', 'P']
    
    new_tokens = []
    for i in range(len(tokens)-5):
        source = ' '.join([tokens[i], tokens[i+1], tokens[i+2]])
        target = ' '.join([tokens[i+3], tokens[i+4], tokens[i+5]])
        new_tokens.append(tokens[i])
        
        if source == target:
            new_tokens.append(tokens[i+1])
            break
        else:
            continue
        
    return new_tokens

class JumanTokenizer():
    def __init__(self):
        self.juman = Juman()

    def tokenize(self, text):
        result = self.juman.analysis(text)
        return [mrph.midasi for mrph in result.mrph_list()]

class BertWithJumanModel(BertTokenizer):
    def __init__(self, args, logger):
        super().__init__(
            args.bert_vocab,
            do_lower_case=False,
            do_basic_tokenize=False
        )
        self.juman_tokenizer = JumanTokenizer()
        self.logger = logger

    def _preprocess_text(self, text):
        return text.replace(" ", "")  # for Juman

    def get_sentence_ids(self, text):
        preprocessed_text = self._preprocess_text(text)
        tokens = self.juman_tokenizer.tokenize(preprocessed_text)
        self.logger.info(tokens)
        bert_tokens = self.tokenize(" ".join(tokens))
        ids = self.convert_tokens_to_ids(["[CLS]"] + bert_tokens[:bert_seq_len] + ["[SEP]"]) # max_seq_len-2

        recv = self.convert_ids_to_tokens(ids)

        return (ids, recv)

    def get_sentence(self, ids):
        tokens = self.convert_ids_to_tokens(ids)
        return self.convert_tokens_to_string(tokens)
    
    def div_sep(self, tokens):
        sentence = ' '.join(tokens)
        sentences = sentence.split('[SEP]')
        sentences = [sent.split() for sent in sentences]
        
        return sentences
    
    def tokens_to_sentence(self, tokens):
        sentence = self.convert_tokens_to_string(tokens)
        sentence = sentence.replace(' ','')
        sentence = re.sub('[,.、。,.][,.、。,.]+', '', sentence)
        return sentence

dialog.py
from argparse import ArgumentParser

import torch
from transformers import BertTokenizer, BertModel, BertConfig

from nn.encoder_decoder import build_model
from ngram.scoring.score_sentence import ScoreSentence
from logger_gen import set_logger
from BJtokenizer import BertWithJumanModel, remove_duplication


def add_args(parser: ArgumentParser):
    parser.add_argument('--model-path', default=None, type=str, help='training target model (.pth)')
    parser.add_argument('--bert-path', default='resource/bert/Japanese_L-12_H-768_A-12_E-30_BPE_WWM/pytorch_model.bin', type=str, help='path of BERT model')
    parser.add_argument('--bert-config', default='resource/bert/Japanese_L-12_H-768_A-12_E-30_BPE_WWM/bert_config.json', type=str, help='path of BERT config file')
    parser.add_argument('--bert-vocab', default='resource/bert/Japanese_L-12_H-768_A-12_E-30_BPE_WWM/vocab.txt', type=str, help='path of BERT torkenizer vocab')
    parser.add_argument('--enc-len', default=128, type=int, help='intput sequence length of encoder')
    parser.add_argument('--max-len', default=62, type=int, help='generated sequence length')
    
    parser.add_argument('--num-layer', default=2, type=int, help='LSTM layer number')
    parser.add_argument('--use-bidirectional', default=False, action='store_true', help='use BiDirectioal LSTM')
    parser.add_argument('--bert-hsize', default=768, type=int, help='hidden vector size of BERT')
    parser.add_argument('--hidden-size', default=256, type=int, help='hidden layer size')
    parser.add_argument('--lstm-hidden-size', default=256, type=int, help='hidden layer size')
    
    parser.add_argument('--used-ngram-model', default='scoring/models/sample.binary', type=str, help='n-gram model for KenLM scoring')
    parser.add_argument('--remove-limit', default=-3.0, type=float, help='cut model\'s response by n-gram scoring')
    parser.add_argument('--resp-gen', default=30, type=int, help='number of times to generate response sentences')
    
    parser.add_argument('--length-dist-mean', default=10, type=int, help='mean of length that generated sentence')
    parser.add_argument('--length-dist-var', default=5, type=int, help='variance of length that generated sentence')
    parser.add_argument('--length-dist-scale', default=1, type=int, help='variance of length that generated sentence')
    
    return parser

# setup argument parser & logger
parser = ArgumentParser('This program is dialogue sysytem.')
parser = add_args(parser)
args = parser.parse_args()
logger = set_logger("dialog", "log/dialog.log")

# GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
logger.info("Device : {}".format(device))

# load BERT model
bert_config = BertConfig.from_json_file(args.bert_config)
bert_model = BertModel.from_pretrained(args.bert_path, config=bert_config)
bert_tokenizer = BertTokenizer.from_pretrained(args.bert_vocab, do_lower_case=False, do_basic_tokenize=False)
vocab_size = bert_tokenizer.vocab_size

model = build_model(
    args=args,
    bert_model=bert_model,
    bert_tokenizer=bert_tokenizer,
    device=device
)
model.to(device=device)

# tokenizer
bert_juman = BertWithJumanModel(args, logger)
# N-gram scoring
score = ScoreSentence(args=args)

# load model & set parameters
try:
    net_dic = torch.load(args.model_path, map_location=device)
    logger.info('loading model ... ')
    model.load_state_dict(net_dic)
    logger.info('Done.')
except Exception as e:
    logger.info('{} does not exist.'.format(args.model_path))

# Use as is train mode to generate various responses
model.train()

# processing of dialog
while True:
    input_s = input('>>')
    (ids, recv) = bert_juman.get_sentence_ids(input_s)
    logger.info(ids)
    logger.info(recv)
    
    # create model input
    querys = [ids] * args.resp_gen
    answrs = [[0]] * args.resp_gen
    model_inputs = [querys, answrs]
    
    # inference
    outs, _ = model(model_inputs, "eval")
    outs = outs.contiguous()
    _, preds = torch.max(outs, 2)

    responses = []
    res_tokens = []
    for pred in preds:
        
        # get string of response
        resp = bert_juman.get_sentence(pred.tolist())
        res_tokens.append(resp)
        
        # get tokens of response
        tokens = bert_juman.convert_ids_to_tokens(pred)
        
        # remove duplication & scoring
        sentence = remove_duplication(tokens)
        sentence = bert_juman.div_sep(sentence)[0]
        scores, sentence = score(' '.join(sentence))
        logger.info('cand: {}, {}'.format(sentence, scores))
        
        responses.append([sentence, scores])
        
    # choice best score
    max_score = -100000.0
    max_ind = 0
    best_sentence = []
    for i, response in enumerate(responses):
        sentence, sent_score = response
        if sent_score > max_score:
            max_score = sent_score
            max_ind = i
            best_sentence = sentence
    logger.info('Before (duplication): {}'.format(res_tokens[max_ind])) # Before removing duplication
    logger.info('After (duplication): {}'.format(responses[max_ind])) # After removing duplication
    logger.info('sys: {}'.format(bert_juman.tokens_to_sentence(best_sentence)))

dialog.pyでは、応答候補をコマンドライン引数で指定された回数だけ生成し、生成した応答候補をN-gram言語モデルでスコアリングします。
また、応答候補にバリエーションを持たせるために、モデルは推論モードではなく、学習モードで動作させています。
推論モードでは、Dropoutレイヤの動作が一定になってしまいランダム性がなくなるためです。

また、BJtokenizer.py で実装した remove_duplication() は、Seq2Seqモデルが発生させがちな、同じ単語列の繰り返しを除去するための関数です。2単語の繰り返しと、3単語の繰り返しを除去することが可能です。

対話を行う際には以下のように dialog.pyを実行します。

python dialog.py \
    --model-path params/mode.pth \
    --enc-len 128 \
    --max-len 64 \
    --bert-path resource/bert/Japanese_L-12_H-768_A-12_E-30_BPE_WWM/pytorch_model.bin \
    --bert-config resource/bert/Japanese_L-12_H-768_A-12_E-30_BPE_WWM/bert_config.json \
    --bert-vocab resource/bert/Japanese_L-12_H-768_A-12_E-30_BPE_WWM/vocab.txt \
    --num-layer 2 \
    --use-bidirectional \
    --bert-hsize 768 \
    --hidden-size 128 \
    --lstm-hidden-size 128 \
    --used-ngram-model ngram/scoring/models/ngram.binary \
    --remove-limit -5.0 \
    --resp-gen 50 \
    --length-dist-mean 15 \
    --length-dist-var 7 \
    --length-dist-scale 1 

こちらのコードも引数が多いので、vscodeのデバッグ機能の使用を推奨します。
先ほど示した launch.json に dialog.py のコンフィグレーションも記述してありますのでそちらをお使いください。

6.1. 対話

>>おはようございます
2022-08-07 00:30:00,176 | INFO | dialog | ['おはよう', 'ございます']
2022-08-07 00:30:00,176 | INFO | dialog | [2, 20401, 26296, 16613, 3]
2022-08-07 00:30:00,177 | INFO | dialog | ['[CLS]', 'おはよう', 'ござい', '##ます', '[SEP]']
2022-08-07 00:30:00,338 | INFO | dialog | cand: ['おはよう', 'ござい', '##ます', '。', 'また', '寒', '##く', 'なり', 'ましょう', '。'], -1.5390460212279398
2022-08-07 00:30:00,340 | INFO | dialog | cand: ['おはよう', 'ござい', '##ます', '。', 'また', '寒', '##く', 'なる', 'みた', '##いで', '##す', 'ね', '。'], -0.8311027577638342
2022-08-07 00:30:00,341 | INFO | dialog | cand: ['おはよう', 'ござい', '##ます', '。', 'また', '寒', '##く', 'なり', 'ました', '。', 'また', '一', '週間', '頑', '##張り', 'ましょう', '。'], -1.0192636977063367
2022-08-07 00:30:00,342 | INFO | dialog | cand: ['おはよう', 'ござい', '##ます', '。', 'また', '寒', '##く', 'なり', 'そう', '##です', '。'], -0.9307650350838713
2022-08-07 00:30:00,343 | INFO | dialog | cand: ['おはよう', 'ござい', '##ます', '。', 'また', '寒', '##く', 'なり', 'そう', '##です', '。'], -0.9307650350838713
2022-08-07 00:30:00,344 | INFO | dialog | cand: ['おはよう', 'ござい', '##ます', '。', 'また', '寒', '##く', 'なり', 'そう', '##です', 'ね', '。'], -0.8654009120203885
2022-08-07 00:30:00,345 | INFO | dialog | cand: ['おはよう', 'ござい', '##ます', '。', 'また', '寒', '##く', 'なり', 'そう', '##です', '。'], -0.9307650350838713
2022-08-07 00:30:00,346 | INFO | dialog | cand: ['おはよう', 'ござい', '##ます', '。', 'また', '寒', '##く', 'なる', 'ので', '体調', 'に', '気', 'を', 'つけて', 'お', '過ごし', 'ください', '。'], -0.9675886409728524
2022-08-07 00:30:00,347 | INFO | dialog | cand: ['おはよう', 'ござい', '##ます', '。', 'また', '寒', '##く', 'なり', 'そう', '##です', 'ね', '。'], -0.8654009120203885
2022-08-07 00:30:00,348 | INFO | dialog | cand: ['おはよう', 'ござい', '##ます', '。', 'また', '寒', '##く', 'なる', 'と', 'は', '。', 'また', '寒', '##く', 'なる', 'と', '良い', 'です', 'ね', '。', 'また', '寒', '##く', 'なる', 'ので', '体調', 'に', '気', 'を', 'つけて', 'お', '過ごし', 'ください', 'ね', '。'], -1.1559675664573703
2022-08-07 00:30:00,349 | INFO | dialog | cand: ['おはよう', 'ござい', '##ます', '。', 'また', '寒', '##く', 'なり', 'ました', 'ね', '。', 'また', '、', 'よろ', '##しく', 'お', '願い', 'し', 'ます', '。'], -0.6966054774457718
2022-08-07 00:30:00,350 | INFO | dialog | cand: ['おはよう', 'ござい', '##ます', '。', 'また', '寒', '##く', 'なる', 'ので', '体調', 'に', '気', 'を', 'つけて', '過ごし', 'ましょう', 'ね', '。'], -1.0263286904042204
2022-08-07 00:30:00,351 | INFO | dialog | cand: ['おはよう', 'ござい', '##ます', '。', 'また', '、', 'また', '寒', '##く', 'なり', 'そう', '##です', 'ね', '。'], -0.8167832432378189
2022-08-07 00:30:00,352 | INFO | dialog | cand: ['おはよう', 'ござい', '##ます', '。', 'また', '寒', '##く', 'なり', 'そう', '##です', '。'], -0.9307650350838713
2022-08-07 00:30:00,353 | INFO | dialog | cand: ['おはよう', 'ござい', '##ます', '。', 'また', '寒', '##く', 'なり', 'ました', '。'], -1.0285498902465497
2022-08-07 00:30:00,354 | INFO | dialog | cand: ['おはよう', 'ござい', '##ます', '。', 'また', '1', '週間', '始まり', 'ました', '。'], -1.2970063627483002
2022-08-07 00:30:00,355 | INFO | dialog | cand: ['おはよう', 'ござい', '##ます', '。', 'また', '今日', 'から', 'また', '頑', '##張り', 'ます', '。'], -1.017103444330368
2022-08-07 00:30:00,357 | INFO | dialog | cand: ['おはよう', 'ござい', '##ます', '。', 'また', '寒', '##く', 'なる', 'ので', '、', '体調', '管理', 'は', '気', 'を', 'つけて', 'ね', '。'], -1.3521428182655357
2022-08-07 00:30:00,358 | INFO | dialog | cand: ['おはよう', 'ござい', '##ます', '。', 'また', '寒', '##く', 'なり', 'そう', '##です', 'ね', '。'], -0.8654009120203885
2022-08-07 00:30:00,359 | INFO | dialog | cand: ['おはよう', 'ござい', '##ます', '。', 'また', '寒', '##く', 'なる', 'ので', '、', 'また', '寒', '##く', 'なる', 'みた', '##いで', '##す', 'ね', '。'], -0.9943853372248633
2022-08-07 00:30:00,360 | INFO | dialog | cand: ['おはよう', 'ござい', '##ます', '。', 'また', '寒', '##く', 'なる', 'ので', '、', 'また', '寒', '##く', 'なる', 'みた', '##いで', '##す', 'ね', '。'], -0.9943853372248633
2022-08-07 00:30:00,361 | INFO | dialog | cand: ['おはよう', 'ござい', '##ます', '。', 'また', '寒', '##く', 'なり', 'そう', '##です', 'ね', '。'], -0.8654009120203885
2022-08-07 00:30:00,362 | INFO | dialog | cand: ['おはよう', 'ござい', '##ます', '。', 'また', '寒', '##く', 'なり', 'そう', '##です', '。'], -0.9307650350838713
2022-08-07 00:30:00,363 | INFO | dialog | cand: ['おはよう', 'ござい', '##ます', '。', 'また', '寒', '##く', 'なる', 'から', '体調', 'が', '気', 'を', 'つけて', 'ね', '。'], -1.3306418438806813
2022-08-07 00:30:00,364 | INFO | dialog | cand: ['おはよう', 'ござい', '##ます', '。', 'また', '寒', '##く', 'なり', 'そう', '##です', '。'], -0.9307650350838713
2022-08-07 00:30:00,365 | INFO | dialog | cand: ['おはよう', 'ござい', '##ます', '。', 'また', '寒', '##く', 'なり', 'そう', '##です', '。'], -0.9307650350838713
2022-08-07 00:30:00,366 | INFO | dialog | cand: ['おはよう', 'ござい', '##ます', '。', '今日', 'から', 'また', '1', '週間', '頑', '##張り', 'ましょう', '。'], -0.5892339273855319
2022-08-07 00:30:00,367 | INFO | dialog | cand: ['おはよう', 'ござい', '##ます', '。', 'また', '寒', '##く', 'なる', 'みた', '##いで', '##す', 'ね', '。'], -0.8311027577638342
2022-08-07 00:30:00,368 | INFO | dialog | cand: ['おはよう', 'ござい', '##ます', '。', 'また', '寒', '##く', 'なる', 'みた', '##いで', '##す', 'ね', '。'], -0.8311027577638342
2022-08-07 00:30:00,369 | INFO | dialog | cand: ['おはよう', 'ござい', '##ます', '。', 'また', '寒', '##く', 'なり', 'そう', '。'], -1.0449508894977235
2022-08-07 00:30:00,370 | INFO | dialog | cand: ['おはよう', 'ござい', '##ます', '。', 'また', '寒', '##く', 'なり', 'そう', '##です', 'ね', '。', 'また', '寒', '##く', 'なり', 'そう', '##です', 'ね', '。'], -0.9198907347992531
2022-08-07 00:30:00,371 | INFO | dialog | cand: ['おはよう', 'ござい', '##ます', '。', 'また', 'また', '寒', '##く', 'なる', 'と', '。'], -1.6968229345808978
2022-08-07 00:30:00,372 | INFO | dialog | cand: ['おはよう', 'ござい', '##ます', '。', 'また', '寒', '##く', 'なる', 'みた', '##いで', '##す', 'ね', '。'], -0.8311027577638342
2022-08-07 00:30:00,373 | INFO | dialog | cand: ['おはよう', 'ござい', '##ます', '。', 'また', '寒', '##く', 'なり', 'そう', '##です', 'ね', '。'], -0.8654009120203885
2022-08-07 00:30:00,375 | INFO | dialog | cand: ['おはよう', 'ござい', '##ます', '。', 'また', '寒', '##く', 'なる', 'から', 'また', '寒', '##く', 'なる', 'ので', '体調', 'に', '気', 'を', 'つけて', 'ね', '。'], -1.078991015065561
2022-08-07 00:30:00,376 | INFO | dialog | cand: ['おはよう', 'ござい', '##ます', '。', 'また', '寒', '##く', 'なり', 'そう', '##です', 'ね', '。'], -0.8654009120203885
2022-08-07 00:30:00,377 | INFO | dialog | cand: ['おはよう', 'ござい', '##ます', '。', 'また', '寒', '##く', 'なる', 'みた', '##いで', '##す', 'ね', '。'], -0.8311027577638342
2022-08-07 00:30:00,378 | INFO | dialog | cand: ['おはよう', 'ござい', '##ます', '。', 'また', '寒', '##く', 'なる', 'ので', '、', 'また', '寒', '##く', 'なる', 'ので', '体調', 'に', '気', 'を', 'つけて', 'ください', 'ね', '。'], -1.0855622981239614
2022-08-07 00:30:00,379 | INFO | dialog | cand: ['おはよう', 'ござい', '##ます', '。', 'また', 'また', '寒', 'さ', 'に', '。'], -1.9063656110538225
2022-08-07 00:30:00,380 | INFO | dialog | cand: ['おはよう', 'ござい', '##ます', '。', 'また', '寒', '##く', 'なる', 'ので', '、', 'また', '寒', '##く', 'なり', 'そう', '##です', 'ね', '。'], -1.0222690919643314
2022-08-07 00:30:00,381 | INFO | dialog | cand: ['おはよう', 'ござい', '##ます', '。', 'また', '寒', '##く', 'なり', 'そう', '##です', '。'], -0.9307650350838713
2022-08-07 00:30:00,382 | INFO | dialog | cand: ['おはよう', 'ござい', '##ます', '。', 'また', '寒', '##く', 'なり', 'そう', '##です', 'ね', '。'], -0.8654009120203885
2022-08-07 00:30:00,383 | INFO | dialog | cand: ['おはよう', 'ござい', '##ます', '。', 'また', '寒', '##く', 'なり', 'そう', '##です', 'ね', '。'], -0.8654009120203885
2022-08-07 00:30:00,384 | INFO | dialog | cand: ['おはよう', 'ござい', '##ます', '。', 'また', '寒', '##く', 'なり', 'ました', '。', 'また', '寒', '##く', 'なり', 'そう', '##です', 'ね', '。'], -1.0059363044132827
2022-08-07 00:30:00,385 | INFO | dialog | cand: ['おはよう', 'ござい', '##ます', '。', 'また', '寒', '##く', 'なり', 'そう', '##です', 'ね', '。'], -0.8654009120203885
2022-08-07 00:30:00,386 | INFO | dialog | cand: ['おはよう', 'ござい', '##ます', '。', 'また', '寒', '##く', 'なる', 'ので', '、', 'また', '寒', '##く', 'なり', 'そう', '##です', 'ね', '。'], -1.0222690919643314
2022-08-07 00:30:00,387 | INFO | dialog | cand: ['おはよう', 'ござい', '##ます', '。', 'また', '寒', '##く', 'なり', 'そう', '##です', 'ね', '。'], -0.8654009120203885
2022-08-07 00:30:00,388 | INFO | dialog | cand: ['おはよう', 'ござい', '##ます', '。', 'また', '寒', '##く', 'なる', 'から', 'また', '寒', '##く', 'なる', 'ので', '体調', 'に', '気', 'を', 'つけて', '。'], -1.1024721344772854
2022-08-07 00:30:00,389 | INFO | dialog | cand: ['おはよう', 'ござい', '##ます', '。', 'また', '寒', '##く', 'なり', 'そう', '##です', 'ね', '。'], -0.8654009120203885
2022-08-07 00:30:00,390 | INFO | dialog | cand: ['おはよう', 'ござい', '##ます', '。', '今日', 'は', '寒い', 'です', 'ね', '。', 'また', '寒', '##く', 'なる', 'と', '良い', 'です', 'ね', '。'], -0.9382374299182704
2022-08-07 00:30:00,390 | INFO | dialog | Before (duplication): おはよう ございます 。 今日 から また 1 週間 頑張り ましょう 。 [SEP] 。 [SEP] 。 [SEP] 。 [SEP] 。 [SEP] 。 [SEP] 。 [SEP] 。 [SEP] 。 [SEP] 。 [SEP] 。 [SEP] 。 [SEP] 。 [SEP] 。 [SEP] 。 [SEP] 。 [SEP] 。 [SEP] 。 [SEP] 。 [SEP] 。 [SEP] 。 [SEP] 。 [SEP] 。 [SEP] 。 [SEP] 。 [SEP] 。 [SEP]
2022-08-07 00:30:00,391 | INFO | dialog | After (duplication): [['おはよう', 'ござい', '##ます', '。', '今日', 'から', 'また', '1', '週間', '頑', '##張り', 'ましょう', '。'], -0.5892339273855319]
2022-08-07 00:30:00,391 | INFO | dialog | sys: おはようございます。今日からまた1週間頑張りましょう。

上は「おはようございます」と入力した際の出力結果です。

結果ログの出力形式は出力順に、次のようになっています。

  • ユーザ入力(>> が付いているもの)
  • ID化形態素系列
  • 形態素系列
  • 応答候補とそのスコア(cand: と付いているもの)
  • remove_duplication の効果の確認(duplication と付いているもの)
  • 対話システムの最終的な応答(sys: と付いているもの)

出力ログを確認すると、remove_duplication() がしっかりと機能していることが確認できます。
また、score_sentence.py のスコアリング方法を調整する際は、応答候補とそのスコアを参考にしながら改良するといいでしょう。

他の対話結果は、ユーザ入力とシステムの最終的な応答のみに絞って紹介します。
右がユーザの発話、左が発話に対するシステムの返答です。

User System
今日のバイト疲れた お疲れ様です。
やっと記事が書けたよ おめでとうございます。
また飲みに行こう! また行きたいです。
今日は寒かった おはようございます。寒いですね。寒いですね。
今日は暑かった 今日は、お疲れ様です。
名古屋までサイクリングした うん。
機械学習難しい そうですね。

コーパスが少なかったという理由もあるのでしょうが、「名古屋までサイクリング」や「機械学習」などの、具体的・専門的な話を振ると返事がそっけないですね...
単純にコーパスの種類・量を増やす、各話題に関するコーパス(Wikipedia記事等)を用いて学習するなどの対処法を検討する必要がありそうです。

7. まとめ

今回は「BERTを使った Attention付き Seq2Seqモデル」を紹介しました。
今回作成したソースコードはMITライセンスでGitHubに公開しています。

ぜひ改良などを施して、人の良き友人となってくれる対話システムを皆さんの手で創り出してください。

8. 参考文献

Qiita記事

  1. 実践PyTorch
  2. 【深層学習】図で理解するAttention機構
  3. 作って理解する Transformer / Attention
  4. Visual Studio Code でPythonファイルをデバッグする方法

Webページ

  1. BERT日本語Pretrainedモデル - 黒橋・褚・村脇研究室
  2. argparse --- コマンドラインオプション、引数、サブコマンドのパーサー
  3. ナレッジは学ばず遊んで覚える。「Word2Vec」で遊んでみた
9
11
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
9
11

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?