Help us understand the problem. What is going on with this article?

BERTをEncoderとするChatbotの作成

はじめに

夏休みということで、以前から作ってみたかったChatbot(といえるかどうか微妙)を作成したので共有しようと思い投稿しました。
本当にやりたいことは長期依存性を持つ会話モデルの作成なのですが、リソースとデータの観点から厳しいですね。

BERTの採用理由は単純な興味からなります。
そもそもTransformerベースのものでChatbotを作ること自体が間違っているという意見が私の中で出ましたが、最近RoBERTa論文がでたので使いたくなってしまいました。
あと、日本語の学習済みモデルがあったからです。

コードはGithubで公開しています。
コードにはbert_pretrainingというフォルダがありますが、特に使っていません。
会話データでfinetuneした方がいいかなと思ったため作っただけです。

結果

先に結果を載せます。
sample (3).png

入力はその時に考えた言葉を入れました。

データ収集するときにキーワードで収集してたのですが、「やば」で収集したデータが他に比べて多かったため、「やば」が入っていない入力に対しても「やばい」「やばすぎw」「やばたにえん」と返してくる気がします。
あと、「それな」も集めてたので結構返してきます。

この辺は対話系の論文などでも挙げられている多様性の問題そのものだと思います。
頻出トークンが積極的に採用されるためにGreedyで候補を選ぶと当たり障りない返答が返ってくるみたいな。
この辺の研究に関してNAISTの中村研の方が論文出してた気がします。

また、この結果からもわかるように「ありがとう」のデータは集めてないためこのような返答になったと考えています。

今回の実装ではまともな返答をしてくれるかどうかなので、会話ではないですね。
会話をするならばHREDや、過去の履歴も入力とするLSTMモデルなどが候補として挙がりますかね。

正直BERTを使う必要あったかと言われると微妙なところです...。

Tokenizer

Tokenizerにはsentencepieceを使用しました。
これはBERTの学習済みモデルがsentencepieceで学習していたためそのまま流用しました。
公開していただきありがとうございます。

Architecture

モデルのアーキテクチャは単純でpytorch-transformersのBERTに、Vanilla TransformerのDecoderを繋げただけです。
これは脳死で決めました。返事を返してくれればいいやくらいの気持ちです。

Decoderの実装に関してFFNの部分だけは、LinearかConv1dかの問題になりましたがConv1dの方がGPUのメモリ食わない気がしたので(気のせいかもしれない)こちらにしました。
ただ、論文を読む感じLinearを想定している気もしますが...。

モデルのコードは以下のようになります。パラメータもコンストラクタのものを使用しました。
Vanilla Transformerとの違いはd_modelが512->768になっています。
Decoderの頭にLinearをはさんでもよかったですが、BERTに合わせました。

  • bert_model_dir: BERTのpretrainedモデルが入っているディレクトリのパス
  • vocab_size: 語彙数(sos_token, eos_tokenなど含む)
  • h: MultiHeadAttentionのHead数
  • d_model: 埋め込みベクトルのサイズおよび隠れ層の次元
  • d_ff: FFNの次元
  • drop_rate: Dropoutの確立
  • max_len: 1文の最大トークン数

何となくわかっていただければと思います。
詳しくはGithubを参照してください。

encoder_decoder.py
class Model(nn.Module):
    def __init__(self, bert_model_dir, vocab_size=32000, h=8,
                 d_model=768, N=6, d_ff=2048, drop_rate=0.1, max_len=12):
        super(Model, self).__init__()

        # target用のembedding
        self.target_emb = build_embedding(vocab_size, d_model, drop_rate, max_len=max_len)

        # BERT encoder
        self.encoder = Encoder.from_pretrained(bert_model_dir)
        # Freeze encoder
        for param in self.encoder.parameters():
            param.requires_grad = False

        # Transformer decoder
        self.decoder = build_decoder(N, h, d_model, d_ff)

        # 出力を(batch_size, max_len, vocab_size)に変換
        # それにlog_softmaxを適用
        # この確率を基に後で文に戻す
        self.generator = Generator(d_model, vocab_size)

    def forward(self, source, source_mask, target, target_mask):
        x = self.encode(source, source_mask)
        x = self.decode(x, source_mask, target, target_mask)
        x = self.generate(x)
        return x

    def encode(self, source, source_mask):
        return self.encoder(source, attention_mask=source_mask)

    def decode(self, mem, source_mask, target, target_mask):
        return self.decoder(self.target_emb(target), mem, source_mask, target_mask)

    def generate(self, x):
        return self.generator(x)

学習

LossとかOptimizerの設定などはThe Annotated Transformerの実装に倣いました。
なので、LossはKLDivLoss、OptimizerはWarm-upありのAdamです。

Twitterから集めたデータ数十万ペアくらいの小規模なデータで20epochくらい回した結果が上の結果になります。

感想

モデルのアーキテクチャを脳死で考えてしまいましたが、まともな(?)返答をしてくれただけで嬉しいです。
まあ、今回の場合はBERT使わなくても通常のTransformerでも十分だと思いますね。
それに、会話だとRNNベースの方がいいですね。
Transformerで過去の会話内容参照するためにはメモリなどに保存して、それ用にEncodingして入力しないといけないのでそういった面ではRNNの方が扱いやすいように思います。
これに関しては上にも書いたTransformer-xlの再帰機構を応用すれば実現できるのでは?と愚考しています。そう単純な話ではなさそうですけどね。

個人的にTransformerの利点の1つはDecoderの訓練時に時間方向の並列化ができる点だと思っているので、学習速度や学習の安定性からTransformerでのアプローチはしていきたいと考えています。

あと、思い知ったのは学習データの重要性ですね。もっとまともな学習データを集めたいです。いいデータセットがあればコメント欄で教えてください。

今後は連続する会話データで長期依存性を持つモデルを作成していきたいと思います。
若い内に色々なことに挑戦していきたいです。

拙い文章ですが読んでいただきありがとうございました。

Why do not you register as a user and use Qiita more conveniently?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away