Qiita Teams that are logged in
You are not logged in to any team

Log in to Qiita Team
Community
OrganizationAdvent CalendarQiitadon (β)
Service
Qiita JobsQiita ZineQiita Blog
Help us understand the problem. What is going on with this article?

Chainerを使ってSeq2Seqモデルによる簡易チャットボットを作ってみた。

More than 3 years have passed since last update.

はじめに

  • 結構前に作ってはいたんですが、投稿するの忘れていました。
  • (当時)最新のバージョン2だ(^o^)と思っていましたが、つくり終えるころにバージョン3が発表されてました...。
  • ネットで転がっているのはすごい人達が汎用性高く書いて有りすぎるので解読できなかったのですが、シンプルに実装できて改めてきちんと理解できたような気がします。

参考

環境

以下、主要なものはHomebrewからインストールしています

# OS/ソフトウェア/ライブラリ バージョン
1 Mac OS X EI Capitan
2 Python 3.6.3
3 Mecab 0.996
4 Chainer 2.0

その他のライブラリはpipからインストールしています

  • Numpy
  • argparseなど

学習データの作成

データ元

データの加工

  1. JSONファイルから必要な部分だけを抜き取ります。
  2. インプットとアウトプットをrequestとresponseに分離します。
  3. Mecabを使ってわかち書きにします。
generate_training_data.py
#!/usr/local/bin/python3
# -*- coding: utf-8 -*-

import argparse
import os
import json
import MeCab

def loadJson(dir_path, f):

    file_path = dir_path + '/' + f
    json_file = open(file_path, 'r')
    json_data = json.load(json_file)
    json_file.close()

    return json_data

def saveText(data, message, file_path):

    tagger = MeCab.Tagger("-Owakati")
    tagger.parse("")

    for i in range(len(data['turns'])):
        if message == "U" and data['turns'][i]['speaker'] == message:
            text = data['turns'][i]['utterance'] + "\n"
        elif message == "S" and data['turns'][i]['speaker'] == message and i != 0:
            text = data['turns'][i]['utterance'] + "\n"
        else:
            continue

        wakati_text = tagger.parse(text)
        file_path.write(wakati_text)

def main():

    parser = argparse.ArgumentParser()
    parser.add_argument('--dir_path', default='/Users/neriai/Downloads/DBDC2_dev/DIT')
    parser.add_argument('--message', required=True)
    args = parser.parse_args()

    if args.message == "U":
        file_path = open('request.txt', 'w')
    elif args.message == "S":
        file_path = open('response.txt', 'w')

    for dir_path, dirs, files in os.walk(args.dir_path):
        for f in files:

            json_data = loadJson(dir_path, f)

            saveText(json_data, args.message, file_path)

    file_path.close()

if __name__ == '__main__':

    main()
repuest.txt
あなた が 、 JR 東日本 と 愛媛 地方 の 四国 で 行っ て いる サイクリングガイド と は 、 どの よう な もの です か ? 。 
みかん を 前面 に 押し出し た JR 西日本 が 販売 し て いる JR 東日本 限定 の 四国 みかん 果汁 を 使用 し た 台風 19 号 の よう な 商品 に は 違和感 を 感じ ます が 、 気 に なる 商品 で は あり ます ね 。 
JR 東日本 と 四国 地方 で 、 愛媛 以外 に 、 あなた が 観光 し て い て 特に 楽しい と 思う 地域 は どこ です か ? 。 
讃岐 うどん の 北海道 日本 ハムファイターズ で 、 おすすめ の メニュー や トッピング は あり ます か ? 。 
私 が 今 まで 旅行 を し た 中 で 、 思い出 に 残っ て いる の は 、 香川 県 へ 旅行 に 行っ た 際 に 食べ た 讃岐 うどん が とても 美味しく て 感激 し た の を 覚え て ます 。 
response.txt
台風 すごかっ た です ね 。 大丈夫 でし た ? 
そんな こと し て ませ ん よ 。 どうして そう 思っ た の です か ? 
みかん 好き です か ? 
讃岐 で うどん 食べる の 好き です 。 
思いつき ませ ん 。 あなた の おすすめ は なん です か ? 

学習の実行

  • 今回、そこまで記述量がないため1ファイルにしました。
learning.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import argparse
import numpy as np

from chainer import Chain, Variable, optimizers, serializers

import chainer.links as L
import chainer.functions as F

class seq2seq(Chain):

    def __init__(
        self,
        request_vocabularies_length,
        response_vocabularies_length,
        embedding_vector_size,
    ):

        super(seq2seq, self).__init__(
            embedx = L.EmbedID(request_vocabularies_length, embedding_vector_size),
            embedy = L.EmbedID(response_vocabularies_length, embedding_vector_size),
            H = L.LSTM(embedding_vector_size, embedding_vector_size),
            W = L.Linear(embedding_vector_size, response_vocabularies_length),
        )

    def __call__(
        self,
        reverse_request_line,
        response_line,
        request_vocabularies,
        response_vocabularies
    ):

        self.H.reset_state()

        for i in range(len(reverse_request_line)):
            word_id = request_vocabularies[reverse_request_line[i]]
            x_k = self.embedx(Variable(np.array([word_id], dtype=np.int32)))
            h = self.H(x_k)

        x_k = self.embedx(Variable(np.array([request_vocabularies['<eos>']], dtype=np.int32)))
        h = self.H(x_k)

        tx = Variable(np.array([response_vocabularies[response_line[0]]], dtype=np.int32))
        accum_loss = F.softmax_cross_entropy(self.W(h), tx)
        accum_acc = F.accuracy(self.W(h), tx)

        for i in range(len(response_line)):
            word_id = response_vocabularies[response_line[i]]
            x_k = self.embedy(Variable(np.array([word_id], dtype=np.int32)))

            if (i == len(response_line) - 1):
                next_word_id = response_vocabularies['<eos>']
            else:
                next_word_id = response_vocabularies[response_line[i + 1]]

            tx = Variable(np.array([next_word_id], dtype=np.int32))
            h = self.H(x_k)

            loss = F.softmax_cross_entropy(self.W(h), tx)
            accum_loss += loss

            acc = F.accuracy(self.W(h), tx)
            accum_acc += acc

        return accum_loss, accum_acc

def main(epochs, request_file, response_file):

    request_vocabularies = {}
    request_lines = open(request_file).read().split('\n')

    for i in range(len(request_lines)):
        line = request_lines[i].split()

        for word in line:
            if word not in request_vocabularies:
                request_vocabularies[word] = len(request_vocabularies)

    request_vocabularies['<eos>'] = len(request_vocabularies)
    request_vocabularies_length = len(request_vocabularies)

    response_vocabularies = {}
    response_lines = open(response_file).read().split('\n')

    for i in range(len(response_lines)):
        line = response_lines[i].split()

        for word in line:
            if word not in response_vocabularies:
                response_vocabularies[word] = len(response_vocabularies)

    response_vocabularies['<eos>'] = len(response_vocabularies)
    response_vocabularies_length = len(response_vocabularies)

    embedding_vector_size = 100

    model = seq2seq(
        request_vocabularies_length,
        response_vocabularies_length,
        embedding_vector_size,
    )

    optimizer = optimizers.Adam()
    optimizer.setup(model)

    for epoch in range(epochs):
        for i in range(len(request_lines)-1):

            request_line = request_lines[i].split()
            reverse_request_line = request_line[::-1]

            response_line = response_lines[i].split()

            model.H.reset_state()
            model.cleargrads()

            loss, acc = model(
                reverse_request_line,
                response_line,
                request_vocabularies,
                response_vocabularies
            )

            loss.backward()
            loss.unchain_backward()

            optimizer.update()

        print('epoch: {} loss: {} accuracy: {}'.format(str(epoch), loss, acc.data))

        outfile = "seq2seq-" + str(epoch) + ".model"
        serializers.save_npz(outfile, model)

if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--epochs', default=100)
    parser.add_argument('--request_file', default='request.txt')
    parser.add_argument('--response_file', default='response.txt')
    args = parser.parse_args()

    main(args.epochs, args.request_file, args.response_file)

対話の検証

  • 見事に破綻したチャットボットになりましたヽ( ´▽)ノ
test.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import argparse
import sys
import re
import numpy as np
import MeCab
from chainer import Chain, Variable, serializers
import chainer.functions as F
import chainer.links as L

def extract_words(text):

    tagger = MeCab.Tagger("-Owakati")
    tagger.parse("")
    wakati = tagger.parse(text)

    words = []
    ws = re.compile(" ")

    for word in ws.split(wakati):
        words.append(word)

    return words

class seq2seq(Chain):

    def __init__(
        self,
        request_vocabularies_length,
        response_vocabularies_length,
        embedding_vector_size,
    ):

        super(seq2seq, self).__init__(
            embedx = L.EmbedID(request_vocabularies_length, embedding_vector_size),
            embedy = L.EmbedID(response_vocabularies_length, embedding_vector_size),
            H = L.LSTM(embedding_vector_size, embedding_vector_size),
            W = L.Linear(embedding_vector_size, response_vocabularies_length),
        )

    def __call__(
        self,
        reverse_request_line,
        response_line,
        request_vocabularies,
        response_vocabularies
    ):

        self.H.reset_state()

        for i in range(len(reverse_request_line)):
            word_id = request_vocabularies[reverse_request_line[i]]
            x_k = self.embedx(Variable(np.array([word_id], dtype=np.int32)))
            h = self.H(x_k)

        x_k = self.embedx(Variable(np.array([request_vocabularies['<eos>']], dtype=np.int32)))
        h = self.H(x_k)

        tx = Variable(np.array([response_vocabularies[response_line[0]]], dtype=np.int32))
        accum_loss = F.softmax_cross_entropy(self.W(h), tx)
        accum_acc = F.accuracy(self.W(h), tx)

        for i in range(len(response_line)):
            word_id = response_vocabularies[response_line[i]]
            x_k = self.embedy(Variable(np.array([word_id], dtype=np.int32)))

            if (i == len(response_line) - 1):
                next_word_id = response_vocabularies['<eos>']
            else:
                next_word_id = response_vocabularies[response_line[i + 1]]

            tx = Variable(np.array([next_word_id], dtype=np.int32))
            h = self.H(x_k)

            loss = F.softmax_cross_entropy(self.W(h), tx)
            accum_loss += loss

            acc = F.accuracy(self.W(h), tx)
            accum_acc += acc

        return accum_loss, accum_acc

def mt(model, words, id2wd, request_vocabularies, response_vocabularies):

    for i in range(len(words)):

        if words[i] not in request_vocabularies:
            print("None Word!!", words[i])
            sys.exit(0)

        word_id = request_vocabularies[words[i]]
        x_k = model.embedx(Variable(np.array([word_id], dtype=np.int32)))
        h = model.H(x_k)

    x_k = model.embedx(Variable(np.array([request_vocabularies['<eos>']], dtype=np.int32)))
    h = model.H(x_k)
    word_id = np.argmax(F.softmax(model.W(h)).data[0])

    output = ''

    if word_id in id2wd:
        output = output + id2wd[word_id]
    else:
        output = output + word_id
    loop = 0

    while (word_id != response_vocabularies['<eos>']):
        x_k = model.embedy(Variable(np.array([word_id], dtype=np.int32)))
        h = model.H(x_k)
        word_id = np.argmax(F.softmax(model.W(h)).data[0])

        if word_id in id2wd:
            output = output + id2wd[word_id]
        else:
            output = output + word_id
        loop += 1

    print(output)

def constructVocabularies(corpus, message):

    vocabularies = {}
    id2wd = {}

    lines = open(corpus).read().split('\n')

    for i in range(len(lines)):
        line = lines[i].split()

        for word in line:
            if word not in vocabularies:
                if message == "U":
                    vocabularies[word] = len(vocabularies)
                elif message == "R":
                    id2wd[len(vocabularies)] = word
                    vocabularies[word] = len(vocabularies)

    if message == "U":
        vocabularies['<eos>'] = len(vocabularies)
        vocabularies_length = len(vocabularies)
        return vocabularies, vocabularies_length
    elif message == "R":
        id2wd[len(vocabularies)] = '<eos>'
        vocabularies['<eos>'] = len(vocabularies)
        vocabularies_length = len(vocabularies)
        return vocabularies, vocabularies_length, id2wd

def main(request_file, response_file, model_file):

    request_vocabularies, request_length = constructVocabularies(request_file, message="U")
    response_vocabularies, response_length, id2wd = constructVocabularies(response_file, message="R")

    embedding_vector_size = 100
    model = seq2seq(request_length, response_length, embedding_vector_size)
    serializers.load_npz(model_file, model)

    while True:
        utterance = input()

        if utterance == "exit":
            print("Bye!!")
            sys.exit(0)

        words = extract_words(utterance)
        words.remove('\n')

        words = words[::-1]

        mt(model, words, id2wd, request_vocabularies, response_vocabularies)

if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--request_file', default='request.txt')
    parser.add_argument('--response_file', default='response.txt')
    parser.add_argument('--model_file', default='seq2seq-50.model')
    args = parser.parse_args()

    main(args.request_file, args.response_file, args.model_file)
neriai
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away