はじめに
以前作った掲示板データから新馬戦の予想を試みて惨敗した記事を基に、チャット形式で予想できるアプリを作りました。
以前からstreamlitのChat elements
は触ってみたかったので、アプリにしてみました。
結論から言うと、とても簡単に実装できるのでぜひ使ってほしいと思いました!
ユーザからの入力を受け取る
以下のようにしてユーザからの入力(prompt)を受け取る部分を作ります。
受け取る際にbotの返答も記録し、チャット履歴として返します。
def input_prompt(chats):
"""ユーザからの入力を受け取る
Args:
chats ([dict]): 元チャット履歴
Returns:
[dict]: 更新後のチャット履歴
"""
# ユーザーからの入力を受け取る
prompt = st.chat_input("Say something")
# ユーザーからの入力があれば、チャット履歴に追加
if prompt:
chats.append({"role": "user", "message": prompt})
# ボットの返答を予測
pred = predict_response(prompt, st.session_state["doc_model"],
st.session_state["ml_model"])
pred = round(pred*100)
chats.append({"role": "bot",
"message": f"3着以内に入る確率は{pred}%です"})
return chats
チャットを表示
チャット履歴を表示します。先ほどのchats
をforで回しつつ、ユーザとボットの発言を判定していきます。
それぞれに合ったエレメントで表示することで、デザイン的にアイコンが変わったりするため、見やすくなります。
def write_chats(chats):
"""チャット履歴を表示する
Args:
chats ([dict]): チャット履歴
"""
# チャット履歴を表示
for chat in chats:
# ユーザーの発言
if chat["role"] == "user":
with st.chat_message("user"):
st.write(f"{chat['message']}")
# ボットの発言
else:
with st.chat_message("ai"):
st.write(f"{chat['message']}")
モデルのロード&予測部分
以前作った予測モデル(LightGBM)とdoc2Vecのモデルをそれぞれpickle形式でロードします。
そしてpromptを受け取ると予測結果を返す部分も作成しました。
import pickle
from gensim.utils import simple_preprocess as preprocess
def load_model():
"""doc2vecのモデルと、lightgbmのモデルをロードする
Returns:
(object, object): (doc2vecのモデル, lightgbmのモデル)
"""
doc_model = None
ml_model = None
with open("doc_model.pickle", "rb") as f:
doc_model = pickle.load(f)
with open("ml_model.pickle", "rb") as f:
ml_model = pickle.load(f)
return doc_model, ml_model
def predict_response(prompt, doc_model, ml_model):
"""ユーザーの入力から、ボットの返答を予測する
Args:
prompt (str): ユーザーの入力
doc_model (object): doc2vecのモデル
ml_model (object): lightgbmのモデル
Returns:
str: ボットの返答
"""
# doc2vecでベクトル化
vec = doc_model.infer_vector(preprocess(prompt))
# lightgbmで予測
pred = ml_model.predict([vec])
return pred[0]
メイン部分
st.session_state
を使って、データを保存するようにしています。あとは先ほど定義した関数を使っていくだけでチャットアプリの完成です。
# チャット履歴を初期化
if "chat_history" not in st.session_state:
st.session_state["chat_history"] = []
# 使いやすい形にする
chats = st.session_state["chat_history"]
# モデルをロード
if "doc_model" not in st.session_state:
st.session_state["doc_model"], st.session_state["ml_model"] = load_model()
# promptの入力を受け付ける
chats = input_prompt(chats)
# チャット履歴を出力
write_chats(chats)
動作画面
こんな感じで使えます。精度は クソ 悪いですが、もっと発展させればかなり楽しく競馬予想できるようなポテンシャルありますね...w
(例に出しているのが新馬戦の話題じゃなくて申し訳ないです...)
まとめ
実際に触ってもらおうと、streamlit.io にもあげたかったんですが、gensimがうまく入らず断念...!
しかし、すごく簡単にchatアプリを作れたのはとてもうれしいですし、もっといろんな使い方できそうですね!