チャットボットを作るのにも一苦労だったので、備忘録。
に完成品があります。
チャットボットの構成
バックエンド Flask で チャットボットの基幹となる部分の実装を、フロント React でチャットボットの表示をしています。
チャットボットの学習済みのAI には Blender という Facebook 製のものを使いました。多くのの自然学習系のライブラリと同じように、
1:入力した文章をトークンに encode して、
2:結果を encode したトークンから生成し、
3:トークンの結果を decode して文字列の結果を取得
する流れになります。
チャットボットが、前に入力した文字列からその後に続くものを予測するものなので、一回一回の入力は前回までのチャットのログ全てになります(前回までのログも最大の長さがあります)。
React のフロントエンドは、単純に 入力されたチャットを Axios でバックエンドに投げて結果を取得するだけです。簡単ですね。
結果
結果だけ、載せます。
興味があったら、参考にしてください。
・バックエンド(Python, Flask)
sample_blenderbot.py
import torch
from base_model import BaseModel
from transformers import (
BlenderbotSmallForConditionalGeneration,
BlenderbotSmallTokenizer,
BlenderbotForConditionalGeneration,
BlenderbotTokenizer
)
class BlenderBot(BaseModel):
def __init__(self, size, device, max_context_length=1024):
assert size in ["small", "medium", "large"], "model size must be one of ['small', 'medium', 'large']"
if size == "small":
super().__init__("facebook/blenderbot_small-90M")
self.model = BlenderbotForConditionalGeneration.from_pretrained(self.name).to(device)
else:
if size == "medium":
super().__init__("facebook/blenderbot-400M-distill")
elif size == "large":
super().__init__("facebook/blenderbot-1B-distill")
self.model = BlenderbotForConditionalGeneration.from_pretrained(self.name).to(device)
self.tokenizer = BlenderbotTokenizer.from_pretrained(self.name)
self.device = device.lower()
self.max_context_length = max_context_length
self.eos = "</s><s>"
self.history_human = {}
self.history_bot = {}
@torch.no_grad()
def predict(
self,
user_id: str,
text: str,
num_beams: int = 5,
top_k: int = 1,
top_p: float = None,
) -> str:
torch.cuda.empty_cache()
input_ids_list: list = []
num_of_stacked_tokens: int = 0
if user_id not in self.history_human.keys():
self.history_human[user_id] = []
self.history_bot[user_id] = []
user_histories = reversed(self.history_human[user_id])
bot_histories = reversed(self.history_bot[user_id])
for user, bot in zip(user_histories, bot_histories):
user_tokens = self.tokenizer.encode(user, return_tensors="pt")
bot_tokens = self.tokenizer.encode(bot, return_tensors="pt")
num_of_stacked_tokens += user_tokens.shape[-1] + bot_tokens.shape[-1]
if num_of_stacked_tokens <= self.max_context_length:
input_ids_list.append(bot_tokens)
input_ids_list.append(user_tokens)
else:
break
input_ids_list = list(reversed(input_ids_list))
new_input = text + self.eos
input_tokens = self.tokenizer.encode(new_input, return_tensors='pt')
input_ids_list.append(input_tokens)
input_tokens = torch.cat(input_ids_list, dim=-1)
input_tokens = input_tokens.to(self.device)
output_ids = self.model.generate(
input_tokens,
max_length=1024,
num_beams=num_beams,
top_k=top_k,
top_p=top_p,
no_repeat_ngram_size=4
)[0]
next_utterance = self.tokenizer.decode(
output_ids.tolist(),
skip_special_tokens=True
).replace("Ġ", "").replace(" ", "")
print(next_utterance)
self.history_human[user_id].append(text + self.eos)
self.history_bot[user_id].append(next_utterance + self.eos)
return next_utterance
app.py
from flask import Flask, render_template, jsonify
from flask_cors import CORS
from sample_blenderbot import BlenderBot
app = Flask(__name__)
CORS(
app,
support_credentials=True
)
b = BlenderBot(size="medium", device="cpu")
@app.route("/")
def index():
return jsonify({"version": "0.01"})
@app.route("/send/<user_id>/<text>")
def send(user_id, text: str):
splitted_text = text.split("_")
merged_text = " ".join(splitted_text)
_out = b.predict(user_id, merged_text)
return jsonify({"text": _out})
if __name__ == "__main__":
app.run(port=8888)
・フロントエンド(Javascript, React)
Chat.tsx
import react, {useEffect, useState} from "react"
import {useParams} from "react-router-dom"
import axios from "axios"
import "./chat.css"
const Chat = () => {
const initial_chat = [
{text: "Please input words.", user: 0}
]
const initial_user = [
{ id: 0, name: "system" },
{ id: 1, name: "bot" },
{ id: 2, name: "user" }
]
const { id } = useParams<{id: string}>()
const [chats, setChats] = useState(initial_chat)
const [text, setText] = useState<string>("")
const [loading, setLoading] = useState(false)
const onChangeTextForm = (e: any) => {
setText(e.target.value)
}
const onClickSendButton = () => {
setLoading(true)
appendChat(2, text)
setText("")
const splitted_new_text = text.split(" ")
const new_text = splitted_new_text.join("_")
axios(`http://localhost:5000/send/${id}/${new_text}`)
.then((result) => {
setLoading(false)
appendChat(1, result.data.text)
})
}
const appendChat = (user_type: number, text: string) => {
setChats((chats) => {
const copied_chat = JSON.parse(JSON.stringify(chats))
console.log(copied_chat)
copied_chat.push({user: user_type, text: text})
return copied_chat
})
}
if (!id) {
return (
<div>
<p>ユーザーを指定してください</p>
</div>
)
}
return (
<div>
<p>ユーザーID : {id} 様</p>
<div className="chat-result-area">
{ chats.map((chat) => {
return (
<>
<div className={ chat.user === 0
? "system-user chat-content"
: chat.user === 1
? "bot-user chat-content"
: "user-user chat-content"
}>
{chat.text}
</div>
</>
)
})}
</div>
<div className="chat-input-area">
<input
type="text"
value={text}
className="input-form"
onChange={(e) => onChangeTextForm(e)}
/>
<button
className="input-button"
onClick={() => onClickSendButton()}
>
{ loading
? "送信中"
: "送信する"
}
</button>
</div>
</div>
)
}
export default Chat;
参考にしたサイト
Transformer などもそうだが 参考にしたコード見る感じ、韓国強い・・・