はじめに
StreamlitとLangchainを組み合わせたときに、単純に処理を組むとChatGPTのようにストリーム表示(応答をリアルタイムに表示)になりません。
順当なやり方かどうかはわかりませんがうまくいった方法を共有しようと思います。
Streamlit用のコールバックハンドラー
LangchainにはAPIのレスポンスをハンドリングするハンドラークラスが用意されています。
ドキュメント上には書かれていません1が、リポジトリ上にStreamlit用のクラス(StreamlitCallbackHandler)があるのでこれを利用します。
しかし、単純にこのクラスを利用するとLangchainが初回に送っているプロンプトの内容なども表示されてしまうので、純粋にレスポンスのみを表示するためにはクラスをラップします。
ラッパークラスをつくる
以下のようにラッパークラスを定義し、コールバックハンドラーとして指定します。(WrapStreamlitCallbackHandler)
class WrapStreamlitCallbackHandler(StreamlitCallbackHandler):
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
pass
llm = ChatOpenAI(
temperature=0,
model_name=model,
streaming=True,
max_tokens=2000,
callback_manager=BaseCallbackManager(
[WrapStreamlitCallbackHandler()],
),
)
例えば、pandas_dataframe_agentを使った場合は、以下のようにするといい感じで動きました。(exceptのあたりはAgentを使ったときによく起こるパースエラーの対策です。)
import os
from typing import Any, Dict, List
import pandas as pd
import streamlit as st
from langchain.agents import create_pandas_dataframe_agent
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.streamlit import StreamlitCallbackHandler
from langchain.chat_models import ChatOpenAI
# promptsの出力を行わないためラップ
class WrapStreamlitCallbackHandler(StreamlitCallbackHandler):
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
pass
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Run on new LLM token. Only available when streaming is enabled."""
self.tokens_stream += token
if "Final Answer:" in self.tokens_stream:
target = "Final Answer:"
idx_s = self.tokens_stream.find(target)
self.tokens_area.write(self.tokens_stream[idx_s + len(target) :])
os.environ["OPENAI_API_KEY"] = st.secrets["OPEN_AI_KEY"]
model = "gpt-3.5-turbo"
file = st.file_uploader("Choose a file", type=["csv"])
text = st.text_input("入力")
if file:
df = pd.read_csv(file)
agent = create_pandas_dataframe_agent(llm, df=df, verbose=True)
if text:
try:
with st.spinner("考え中..."):
st.write("## 回答")
llm = ChatOpenAI(
temperature=0,
model_name=model,
streaming=True,
max_tokens=2000,
callback_manager=BaseCallbackManager(
[WrapStreamlitCallbackHandler()],
),
)
response = agent.run(text + "Be sure to answer in Japanese.")
except Exception as e:
response = str(e)
if not response.startswith("Could not parse LLM output: `"):
raise e
response = response.removeprefix("Could not parse LLM output: `").removesuffix(
"`"
)
おわりに
Langchain自体がすごい勢いでアップデートされているので、しばらくしたら正式にドキュメントに載る&パラメータなども充実してくるかもしれません。
-
GPT4Allのところでこっそり紹介されてます。(https://python.langchain.com/en/latest/ecosystem/gpt4all.html?highlight=StreamlitCallbackHandler) ↩