LoginSignup
18
11

LangChainのレスポンスをStreamlitでStream表示する

Posted at

はじめに

StreamlitとLangchainを組み合わせたときに、単純に処理を組むとChatGPTのようにストリーム表示(応答をリアルタイムに表示)になりません。

順当なやり方かどうかはわかりませんがうまくいった方法を共有しようと思います。

Streamlit用のコールバックハンドラー

LangchainにはAPIのレスポンスをハンドリングするハンドラークラスが用意されています。

ドキュメント上には書かれていません1が、リポジトリ上にStreamlit用のクラス(StreamlitCallbackHandler)があるのでこれを利用します。

しかし、単純にこのクラスを利用するとLangchainが初回に送っているプロンプトの内容なども表示されてしまうので、純粋にレスポンスのみを表示するためにはクラスをラップします。

ラッパークラスをつくる

以下のようにラッパークラスを定義し、コールバックハンドラーとして指定します。(WrapStreamlitCallbackHandler)

Stream
class WrapStreamlitCallbackHandler(StreamlitCallbackHandler):
    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        pass

llm = ChatOpenAI(
    temperature=0,
    model_name=model,
    streaming=True,
    max_tokens=2000,
    callback_manager=BaseCallbackManager(
        [WrapStreamlitCallbackHandler()],
    ),
)

例えば、pandas_dataframe_agentを使った場合は、以下のようにするといい感じで動きました。(exceptのあたりはAgentを使ったときによく起こるパースエラーの対策です。)

Sidebar
import os
from typing import Any, Dict, List

import pandas as pd
import streamlit as st
from langchain.agents import create_pandas_dataframe_agent
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.streamlit import StreamlitCallbackHandler
from langchain.chat_models import ChatOpenAI


# promptsの出力を行わないためラップ
class WrapStreamlitCallbackHandler(StreamlitCallbackHandler):
    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        pass

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run on new LLM token. Only available when streaming is enabled."""
        self.tokens_stream += token
        if "Final Answer:" in self.tokens_stream:
            target = "Final Answer:"
            idx_s = self.tokens_stream.find(target)
            self.tokens_area.write(self.tokens_stream[idx_s + len(target) :])


os.environ["OPENAI_API_KEY"] = st.secrets["OPEN_AI_KEY"]

model = "gpt-3.5-turbo"

file = st.file_uploader("Choose a file", type=["csv"])
text = st.text_input("入力")

if file:
    df = pd.read_csv(file)
    agent = create_pandas_dataframe_agent(llm, df=df, verbose=True)

    if text:
        try:
            with st.spinner("考え中..."):
                st.write("## 回答")
                llm = ChatOpenAI(
                    temperature=0,
                    model_name=model,
                    streaming=True,
                    max_tokens=2000,
                    callback_manager=BaseCallbackManager(
                        [WrapStreamlitCallbackHandler()],
                    ),
                )

                response = agent.run(text + "Be sure to answer in Japanese.")
        except Exception as e:
            response = str(e)
            if not response.startswith("Could not parse LLM output: `"):
                raise e
            response = response.removeprefix("Could not parse LLM output: `").removesuffix(
                "`"
            )

Animation.gif

おわりに

Langchain自体がすごい勢いでアップデートされているので、しばらくしたら正式にドキュメントに載る&パラメータなども充実してくるかもしれません。

  1. GPT4Allのところでこっそり紹介されてます。(https://python.langchain.com/en/latest/ecosystem/gpt4all.html?highlight=StreamlitCallbackHandler)

18
11
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
18
11