Edited at

PyTorchではじめるチャットボット

More than 1 year has passed since last update.


対象読者


  • チャットボットを作成したい方

  • QRNNがどの分野で適用可能か知りたい方

  • PyTorchを試してみたい方


はじめに

Chainerで学習した対話用のボットをSlackで使用+Twitterから学習データを取得してファインチューニングがおかげ様で人気な記事になっているのでPyTorchを使用して同様のチャットボットを作成しました。

チャット例

Screenshot from 2017-12-03 13:19:58.png


システム構成

システム構成の違いは事前学習なしで学習している点です。それでも動作したQRNNはかなりのパフォーマンスを発揮できることが期待できます。

以前の構成

image.png

今回の構成

Screenshot from 2017-12-03 13:51:05.png


手法


QRNN


  • RNN: 過去の情報を利用して学習するため長い文章生成やセンサーデータなどに対して有効に働くが構造上、過去の情報を使用するため並列化処理が難しい。

  • CNN: RNNに比べて並列化が容易であり、自然言語の分野にもいくつか適用されてきているがRNNと違い過去の情報を考慮した長い文章生成やセンサーデータは苦手

両者の良い所取りをしたのがQRNNです。

QRNNに関しては良記事があるのでそちらをご参照下さい。

LSTMを超える期待の新星、QRNN


QRNN Encoder Decoder

QRNNはRNNの擬似モデルのため、今までRNNを使用してた領域にも使用可能です。Attentionを使用しないQRNN Encoder DecoderをPyTorchで実装しました。

Encoder部分の処理とDecoder部分の処理は通常のQRNNと同一ですがDecoderにEncoderの内容を反映させる必要があります。

ではどのように反映させるかですがEncoderの最終の隠れ層をDecoderの全出力と畳み込み処理をかけることで反映しています。

ここでVはエンコーダーの最終の隠れ層をデコーダーの次元に合わせるのに使用しています。

具体的な数式は下記です。

Z^l = tanh(W^{l}_z*X^l+V^{l}_z\tilde{h}^{l}_T) \\

F^l = \sigma(W^{l}_f*X^l+V^{l}_f\tilde{h}^{l}_T) \\
O^l = \sigma(W^{l}_o*X^l+V^{l}_o\tilde{h}^{l}_T) \\

コードで追ってみると

まずEncoder部分で最終層の状態を保持します。


c_last = c[range(len(inputs)), (input_len - 1).data, :]
h_last = h[range(len(inputs)), (input_len - 1).data, :]
cell_states.append(c_last)
hidden_states.append((h_last, h))

https://github.com/SnowMasaya/pytorch-chat-bot/blob/86758b7047aaaa704e825e39698250704bd41958/src/models/encoder_qrnn.py#L43

Decoderで保持された状態を受け取ります。


c, h = layer(h, state, memory)

https://github.com/SnowMasaya/pytorch-chat-bot/blob/86758b7047aaaa704e825e39698250704bd41958/src/models/decoder_qrnn.py#L54

状態の情報を受け取り畳み込み処理を行います。


(conv_memory, attention_memory) = memory

Z, F, O = self._conv_step(inputs, conv_memory)

https://github.com/SnowMasaya/pytorch-chat-bot/blob/86758b7047aaaa704e825e39698250704bd41958/src/models/qrnn_layer.py#L70

下記で取得したメモリーがあれば上式のVの処理で同一空間にマップして足しています。


if memory is not None:
gates = gates + self.conv_linear(memory).unsqueeze(1)


モデル構成

Tensorboardで出力したかったのですが時間がかかりすぎて断念しました。

Screenshot from 2017-12-03 17:03:06.png


結果

学習例

Attentionと比較しています。学習結果に差がありすぎるためAttentionの実装をチェックする必要がありますが圧倒的な差でQRNNのロスが少なくなっています。

時間の都合上、バリデーションデータ、テストデータを用意した評価は行っていません。

また正則化処理であるZoneOut、ドロップアウトは使用していないので過学習気味な気がします。

学習データはTwitterのデータを使用した対話文章をユニークな文章428文、学習データとしては不十分なため7704文にかさ増しして学習しました。(同じ文章をコピペして増やしただけです)

各種モデルのパラメータはコードを参照下さい。

茶色がQRNNで水色がAttentionです。

Screenshot from 2017-12-03 17:07:38.png


実際にチャットをしてみたい方へ

デモ動画

pytorch-chat.mov.gif

Slackで対話できるようにしました!!

下記のように入力してもらえると対話可能です。

pytorch: {対話したい内容}


今後のチャットデモの対応


  • データを増やして学習したモデルの導入

  • 正則化処理をいれた場合の比較し性能が良ければ導入

  • QRNN+Attentionの比較し性能が良ければ導入

  • 検索機能の導入

etc


コード

https://github.com/SnowMasaya/pytorch-chat-bot


参考

https://arxiv.org/pdf/1611.01576.pdf

http://pytorch.org/

https://github.com/JayParks/quasi-rnn

http://musyoku.github.io/2017/05/30/Quasi-Recurrent-Neural-Networks/