langchainのdiscordでも質問多発する「どうやって文字ストリーミングするの?」問題
LLMアプリをつくる人なら誰しもChatGPTのように一文字(正確にはトークン)ずつカタカタカタっとアウトプットしたいと思うはず。ただの見栄えの問題にも見えますが、出力し始めて、「あ、これは質問の意味履き違えてるな」と気づけたりするので特に長い回答の場合はストリーミングでのアウトプットは助かります。
実際にlangchainのdiscordコミュニティでも超頻出する質問です。
意外と共有されていないストリーミングの実装方法
標準出力のストリーミングのやり方自体はさらっと載っていて、簡単です。(LLMをインスタンス化する時にstremingをTrueにするのみ)
https://python.langchain.com/en/latest/modules/models/chat/getting_started.html?highlight=streaming#streaming
しかし、これをどうやって実際のアプリで出力するのか?の情報はなかなか見つかりませんでした。人によってはトークンを全部受け取ったあとに、一トークンずつ表示する、ということをやっている人もいました。しかしこれでは「ダメなアウトプットに早めに気づく」ということは出来ません。むしろもっと長く待たされることになります。
これを実装するにはCallbacksを使う必要があります。
カルボナーラ以外全部朝ごはんっぽいですね。。
手っ取り早く試したい人はこちら
まずはlangchainから
from langchain.callbacks.base import BaseCallbackHandler
をインポートした後以下のコールバックハンドラーを作ります
class SimpleStreamlitCallbackHandler(BaseCallbackHandler):
""" Copied only streaming part from StreamlitCallbackHandler """
def __init__(self) -> None:
self.tokens_area = st.empty()
self.tokens_stream = ""
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Run on new LLM token. Only available when streaming is enabled."""
self.tokens_stream += token
self.tokens_area.markdown(self.tokens_stream)
これをインスタンス化
handler = SimpleStreamlitCallbackHandler()
して下記のように入れ込みます。ポイントはst.emptyなどの書けるスペースを用意してそこの文字列(markdownなど)をChainに更新させることです。
*callbacksはllmやchainなど色んなところに入れられますが、この場合はpredictの中じゃないと正しく動きませんでした。ここは僕も良くわかってません。
res_box = st.empty()
chat = ChatOpenAI(streaming=True, temperature=0.9)
conversation = ConversationChain(
llm=chat,
prompt=prompt,
memory=state['memory']
)
res = conversation.predict(input=user_input, callbacks=[handler])
*実はStreamlitCallbackHandlerというど直球なハンドラーがあるのですが、入らないコールバックが沢山あったので、シンプルにする意味で独自でクラスを作っています。
動くコード全体像
https://streaming-and-memory.streamlit.app/
実際に試してみてください。
import streamlit as st
from langchain. chat_models import ChatOpenAI
from langchain import PromptTemplate
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import (
HumanMessage,
)
from langchain.chains import ConversationChain
from langchain.memory import ConversationBufferMemory
import openai
from typing import Any, Dict, List
st.header("AMA")
st.subheader("Streamlit + ChatGPT + Langchain with `stream=True`")
def get_state():
if "state" not in st.session_state:
st.session_state.state = {"memory": ConversationBufferMemory(memory_key="chat_history")}
return st.session_state.state
state = get_state()
st.write(state['memory'].load_memory_variables({}))
prompt = PromptTemplate(
input_variables=["chat_history","input"],
template='Based on the following chat_history, Please reply to the question in format of markdown. history: {chat_history}. question: {input}'
)
user_input = st.text_input("You: ",placeholder = "Ask me anything ...")
ask = st.button('ask',type='primary')
st.markdown("----")
class SimpleStreamlitCallbackHandler(BaseCallbackHandler):
""" Copied only streaming part from StreamlitCallbackHandler """
def __init__(self) -> None:
self.tokens_area = st.empty()
self.tokens_stream = ""
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Run on new LLM token. Only available when streaming is enabled."""
self.tokens_stream += token
self.tokens_area.markdown(self.tokens_stream)
handler = SimpleStreamlitCallbackHandler()
if ask:
res_box = st.empty()
with st.spinner('typing...'):
report = []
chat = ChatOpenAI(streaming=True, temperature=0.9)
conversation = ConversationChain(
llm=chat,
prompt=prompt,
memory=state['memory']
)
res = conversation.predict(input=user_input, callbacks=[handler])
st.markdown("----")
Callbacksと仲良くなろう!
さて、上記で一応動きますが、せっかくなのでコールバックと仲良くなりましょう。
langchainにおけるコールバックは簡単に言うと、LLMが動きのタイミングに合わせてお好きな関数が走るようにできるやつです。
タイミングには
- on_llm_start
- on_llm_new_token
- on_llm_end
など、基本的には何かが始まるか終わるタイミングで関数を走らせられます。llmの他にもchain、tool、agentなど様々なアクターの始まったり終わったりするタイミングにも設定できます。
この中で今回使ったのはon_llm_new_tokenです。これはトークンが生成される度に、
self.tokens_stream += token
self.tokens_area.markdown(self.tokens_stream)
と言う形でまず既存の文字列に新たに文字を追加して、area(今回の例ではst.empty())にマークダウンとして表示しています。
まさに今回の目的にぴったりですね。
注意点
2023年6月時点でCallbacksは導入されて日が浅いため、様々なChain、Agentでは旧来のCallbackManagerが使われていてCallbacksを渡しても無視される可能性があります。動かない場合はGithubのコードを読んでcallbacksが受け取れるようになってるか確認して、まだの場合は是非PRだすかイシューを出しましょう。
https://github.com/hwchase17/langchain
いますぐ実装する必要があれば既存のChainやAgentを改造してcallbacksを受け取れるようにすれば大丈夫ですが、メンテナンスが大変なのでおすすめはしません。
てか文字ストリームするのにここまでカロリーかけるのも、って感じしますよね🤣でもこれほんとみんながハマってるポイントで僕も色々聞いて回ってやっと解決できました。
同じくstreamlitで実装時にハマるポイントとしてメモリの管理の仕方をまとめたのでこちらもどうぞ
https://qiita.com/yazoo/items/70b7cf73eb3b232ca423