はじめに
- 結構前に作ってはいたんですが、投稿するの忘れていました。
- (当時)最新のバージョン2だ(^o^)と思っていましたが、つくり終えるころにバージョン3が発表されてました...。
- ネットで転がっているのはすごい人達が汎用性高く書いて有りすぎるので解読できなかったのですが、シンプルに実装できて改めてきちんと理解できたような気がします。
参考
環境
以下、主要なものはHomebrewからインストールしています
# | OS/ソフトウェア/ライブラリ | バージョン |
---|---|---|
1 | Mac OS X | EI Capitan |
2 | Python | 3.6.3 |
3 | Mecab | 0.996 |
4 | Chainer | 2.0 |
その他のライブラリはpipからインストールしています
- Numpy
- argparseなど
学習データの作成
データ元
- おなじみの対話破綻コーパスを利用させていただきました。
データの加工
- JSONファイルから必要な部分だけを抜き取ります。
- インプットとアウトプットをrequestとresponseに分離します。
- Mecabを使ってわかち書きにします。
generate_training_data.py
#!/usr/local/bin/python3
# -*- coding: utf-8 -*-
import argparse
import os
import json
import MeCab
def loadJson(dir_path, f):
file_path = dir_path + '/' + f
json_file = open(file_path, 'r')
json_data = json.load(json_file)
json_file.close()
return json_data
def saveText(data, message, file_path):
tagger = MeCab.Tagger("-Owakati")
tagger.parse("")
for i in range(len(data['turns'])):
if message == "U" and data['turns'][i]['speaker'] == message:
text = data['turns'][i]['utterance'] + "\n"
elif message == "S" and data['turns'][i]['speaker'] == message and i != 0:
text = data['turns'][i]['utterance'] + "\n"
else:
continue
wakati_text = tagger.parse(text)
file_path.write(wakati_text)
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--dir_path', default='/Users/neriai/Downloads/DBDC2_dev/DIT')
parser.add_argument('--message', required=True)
args = parser.parse_args()
if args.message == "U":
file_path = open('request.txt', 'w')
elif args.message == "S":
file_path = open('response.txt', 'w')
for dir_path, dirs, files in os.walk(args.dir_path):
for f in files:
json_data = loadJson(dir_path, f)
saveText(json_data, args.message, file_path)
file_path.close()
if __name__ == '__main__':
main()
repuest.txt
あなた が 、 JR 東日本 と 愛媛 地方 の 四国 で 行っ て いる サイクリングガイド と は 、 どの よう な もの です か ? 。
みかん を 前面 に 押し出し た JR 西日本 が 販売 し て いる JR 東日本 限定 の 四国 みかん 果汁 を 使用 し た 台風 19 号 の よう な 商品 に は 違和感 を 感じ ます が 、 気 に なる 商品 で は あり ます ね 。
JR 東日本 と 四国 地方 で 、 愛媛 以外 に 、 あなた が 観光 し て い て 特に 楽しい と 思う 地域 は どこ です か ? 。
讃岐 うどん の 北海道 日本 ハムファイターズ で 、 おすすめ の メニュー や トッピング は あり ます か ? 。
私 が 今 まで 旅行 を し た 中 で 、 思い出 に 残っ て いる の は 、 香川 県 へ 旅行 に 行っ た 際 に 食べ た 讃岐 うどん が とても 美味しく て 感激 し た の を 覚え て ます 。
response.txt
台風 すごかっ た です ね 。 大丈夫 でし た ?
そんな こと し て ませ ん よ 。 どうして そう 思っ た の です か ?
みかん 好き です か ?
讃岐 で うどん 食べる の 好き です 。
思いつき ませ ん 。 あなた の おすすめ は なん です か ?
学習の実行
- 今回、そこまで記述量がないため1ファイルにしました。
learning.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import argparse
import numpy as np
from chainer import Chain, Variable, optimizers, serializers
import chainer.links as L
import chainer.functions as F
class seq2seq(Chain):
def __init__(
self,
request_vocabularies_length,
response_vocabularies_length,
embedding_vector_size,
):
super(seq2seq, self).__init__(
embedx = L.EmbedID(request_vocabularies_length, embedding_vector_size),
embedy = L.EmbedID(response_vocabularies_length, embedding_vector_size),
H = L.LSTM(embedding_vector_size, embedding_vector_size),
W = L.Linear(embedding_vector_size, response_vocabularies_length),
)
def __call__(
self,
reverse_request_line,
response_line,
request_vocabularies,
response_vocabularies
):
self.H.reset_state()
for i in range(len(reverse_request_line)):
word_id = request_vocabularies[reverse_request_line[i]]
x_k = self.embedx(Variable(np.array([word_id], dtype=np.int32)))
h = self.H(x_k)
x_k = self.embedx(Variable(np.array([request_vocabularies['<eos>']], dtype=np.int32)))
h = self.H(x_k)
tx = Variable(np.array([response_vocabularies[response_line[0]]], dtype=np.int32))
accum_loss = F.softmax_cross_entropy(self.W(h), tx)
accum_acc = F.accuracy(self.W(h), tx)
for i in range(len(response_line)):
word_id = response_vocabularies[response_line[i]]
x_k = self.embedy(Variable(np.array([word_id], dtype=np.int32)))
if (i == len(response_line) - 1):
next_word_id = response_vocabularies['<eos>']
else:
next_word_id = response_vocabularies[response_line[i + 1]]
tx = Variable(np.array([next_word_id], dtype=np.int32))
h = self.H(x_k)
loss = F.softmax_cross_entropy(self.W(h), tx)
accum_loss += loss
acc = F.accuracy(self.W(h), tx)
accum_acc += acc
return accum_loss, accum_acc
def main(epochs, request_file, response_file):
request_vocabularies = {}
request_lines = open(request_file).read().split('\n')
for i in range(len(request_lines)):
line = request_lines[i].split()
for word in line:
if word not in request_vocabularies:
request_vocabularies[word] = len(request_vocabularies)
request_vocabularies['<eos>'] = len(request_vocabularies)
request_vocabularies_length = len(request_vocabularies)
response_vocabularies = {}
response_lines = open(response_file).read().split('\n')
for i in range(len(response_lines)):
line = response_lines[i].split()
for word in line:
if word not in response_vocabularies:
response_vocabularies[word] = len(response_vocabularies)
response_vocabularies['<eos>'] = len(response_vocabularies)
response_vocabularies_length = len(response_vocabularies)
embedding_vector_size = 100
model = seq2seq(
request_vocabularies_length,
response_vocabularies_length,
embedding_vector_size,
)
optimizer = optimizers.Adam()
optimizer.setup(model)
for epoch in range(epochs):
for i in range(len(request_lines)-1):
request_line = request_lines[i].split()
reverse_request_line = request_line[::-1]
response_line = response_lines[i].split()
model.H.reset_state()
model.cleargrads()
loss, acc = model(
reverse_request_line,
response_line,
request_vocabularies,
response_vocabularies
)
loss.backward()
loss.unchain_backward()
optimizer.update()
print('epoch: {} loss: {} accuracy: {}'.format(str(epoch), loss, acc.data))
outfile = "seq2seq-" + str(epoch) + ".model"
serializers.save_npz(outfile, model)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--epochs', default=100)
parser.add_argument('--request_file', default='request.txt')
parser.add_argument('--response_file', default='response.txt')
args = parser.parse_args()
main(args.epochs, args.request_file, args.response_file)
対話の検証
- 見事に破綻したチャットボットになりましたヽ( ´▽)ノ
test.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import argparse
import sys
import re
import numpy as np
import MeCab
from chainer import Chain, Variable, serializers
import chainer.functions as F
import chainer.links as L
def extract_words(text):
tagger = MeCab.Tagger("-Owakati")
tagger.parse("")
wakati = tagger.parse(text)
words = []
ws = re.compile(" ")
for word in ws.split(wakati):
words.append(word)
return words
class seq2seq(Chain):
def __init__(
self,
request_vocabularies_length,
response_vocabularies_length,
embedding_vector_size,
):
super(seq2seq, self).__init__(
embedx = L.EmbedID(request_vocabularies_length, embedding_vector_size),
embedy = L.EmbedID(response_vocabularies_length, embedding_vector_size),
H = L.LSTM(embedding_vector_size, embedding_vector_size),
W = L.Linear(embedding_vector_size, response_vocabularies_length),
)
def __call__(
self,
reverse_request_line,
response_line,
request_vocabularies,
response_vocabularies
):
self.H.reset_state()
for i in range(len(reverse_request_line)):
word_id = request_vocabularies[reverse_request_line[i]]
x_k = self.embedx(Variable(np.array([word_id], dtype=np.int32)))
h = self.H(x_k)
x_k = self.embedx(Variable(np.array([request_vocabularies['<eos>']], dtype=np.int32)))
h = self.H(x_k)
tx = Variable(np.array([response_vocabularies[response_line[0]]], dtype=np.int32))
accum_loss = F.softmax_cross_entropy(self.W(h), tx)
accum_acc = F.accuracy(self.W(h), tx)
for i in range(len(response_line)):
word_id = response_vocabularies[response_line[i]]
x_k = self.embedy(Variable(np.array([word_id], dtype=np.int32)))
if (i == len(response_line) - 1):
next_word_id = response_vocabularies['<eos>']
else:
next_word_id = response_vocabularies[response_line[i + 1]]
tx = Variable(np.array([next_word_id], dtype=np.int32))
h = self.H(x_k)
loss = F.softmax_cross_entropy(self.W(h), tx)
accum_loss += loss
acc = F.accuracy(self.W(h), tx)
accum_acc += acc
return accum_loss, accum_acc
def mt(model, words, id2wd, request_vocabularies, response_vocabularies):
for i in range(len(words)):
if words[i] not in request_vocabularies:
print("None Word!!", words[i])
sys.exit(0)
word_id = request_vocabularies[words[i]]
x_k = model.embedx(Variable(np.array([word_id], dtype=np.int32)))
h = model.H(x_k)
x_k = model.embedx(Variable(np.array([request_vocabularies['<eos>']], dtype=np.int32)))
h = model.H(x_k)
word_id = np.argmax(F.softmax(model.W(h)).data[0])
output = ''
if word_id in id2wd:
output = output + id2wd[word_id]
else:
output = output + word_id
loop = 0
while (word_id != response_vocabularies['<eos>']):
x_k = model.embedy(Variable(np.array([word_id], dtype=np.int32)))
h = model.H(x_k)
word_id = np.argmax(F.softmax(model.W(h)).data[0])
if word_id in id2wd:
output = output + id2wd[word_id]
else:
output = output + word_id
loop += 1
print(output)
def constructVocabularies(corpus, message):
vocabularies = {}
id2wd = {}
lines = open(corpus).read().split('\n')
for i in range(len(lines)):
line = lines[i].split()
for word in line:
if word not in vocabularies:
if message == "U":
vocabularies[word] = len(vocabularies)
elif message == "R":
id2wd[len(vocabularies)] = word
vocabularies[word] = len(vocabularies)
if message == "U":
vocabularies['<eos>'] = len(vocabularies)
vocabularies_length = len(vocabularies)
return vocabularies, vocabularies_length
elif message == "R":
id2wd[len(vocabularies)] = '<eos>'
vocabularies['<eos>'] = len(vocabularies)
vocabularies_length = len(vocabularies)
return vocabularies, vocabularies_length, id2wd
def main(request_file, response_file, model_file):
request_vocabularies, request_length = constructVocabularies(request_file, message="U")
response_vocabularies, response_length, id2wd = constructVocabularies(response_file, message="R")
embedding_vector_size = 100
model = seq2seq(request_length, response_length, embedding_vector_size)
serializers.load_npz(model_file, model)
while True:
utterance = input()
if utterance == "exit":
print("Bye!!")
sys.exit(0)
words = extract_words(utterance)
words.remove('\n')
words = words[::-1]
mt(model, words, id2wd, request_vocabularies, response_vocabularies)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--request_file', default='request.txt')
parser.add_argument('--response_file', default='response.txt')
parser.add_argument('--model_file', default='seq2seq-50.model')
args = parser.parse_args()
main(args.request_file, args.response_file, args.model_file)