有人チャットのツールを普段つくっている身として、AIによる無人チャットボットの知見をもっておく必要があると常々思っていたので、ここ最近LangChainをいじっています。
ChatGPTでの業務効率化を“断念”──正答率94%でも「ごみ出し案内」をAIに託せなかったワケ 三豊市と松尾研の半年間という記事に、「返答のリアルタイム(ストリーミング)表示に切り替えられるようにした」という改善内容があったので、今回はこちらを試しに実装してみました。
実装の概要
この実装ではLangChainを使用し、LLMに基づく応答をリアルタイムで処理し表示します。
以下が主な構成要素です。
-
ChatOpenAI
: OpenAIのモデルを利用するためのクラス。 -
ChatPromptTemplate
: チャット入力をフォーマットするためのテンプレート。 -
LLMChain
: LLMを使用した処理のチェーン。 -
StreamingHandler
: LLMからの応答をリアルタイムで処理するためのコールバックハンドラ。 - スレッドとキュー: 非同期処理とデータのやり取りを行うための要素。
設計
リアルタイム表示を実現するために、設計には以下の要素を入れました。
-
非同期処理: Pythonのスレッドを使用して、LLMの処理を非同期に実行するようにしました。これによって、LLMからの応答を待つ間にもUI(今回は未実装)がフリーズすることなく、他の処理を続けられるようにしました。
-
キューの利用とデータ整合性: スレッド間でのデータの受け渡しにキューを用いることで、複数のスレッドが同時にアクセスする際の競合を避け、データの整合性を保つようにしました。これはリアルタイムシステムにおいて重要で、データの不整合や予期せぬエラーを防ぎ、全体的なシステムの信頼性を高める役割を果たします。
-
ストリーミング処理: LLMからの応答をトークン単位で取得し、その都度処理することで、リアルタイムのフィードバックがユーザーに提供されるようにしました。
実装
コードの中身を解説します(GitHubはこちら)。
OPEN_API_KEY
を取得済であれば、Dockerで挙動を確かめていただくこともできます。
基本セットアップ
必要なライブラリをインポートし、環境変数を読み込むためにload_dotenv()
を呼び出します。これにより、外部からの設定情報(OPENAI_API_KEY
)を安全に管理できます。
import logging
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.chains import LLMChain
from langchain.callbacks.base import BaseCallbackHandler
from dotenv import load_dotenv
from queue import Queue
from threading import Thread
load_dotenv()
StreamingHandler
StreamingHandler
クラスは、LLMからの応答をリアルタイムで処理するためのカスタムハンドラです。
LLMから新しいトークンが得られるたび、またはエラーが発生した際に呼び出されます。
-
on_llm_new_token
: 新しいトークンをキューに追加します。 -
on_llm_end
: 処理の終了をキューに通知します。 -
on_llm_error
: エラーが発生した場合にログを記録し、処理の終了をキューに通知します。
class StreamingHandler(BaseCallbackHandler):
def __init__(self, queue):
self.queue = queue
def on_llm_new_token(self, token, **kwargs):
self.queue.put(token)
def on_llm_end(self, response, **kwargs):
self.queue.put(None)
def on_llm_error(self, error, **kwargs):
logging.error(f"Error in LLM: {error}")
self.queue.put(None)
StreamingChain
StreamingChain
クラスは、LLMからのデータをストリーミングするためのメインのクラスです。
LLMの応答をリアルタイムで処理するために、スレッドとキューを使用します。
-
stream
メソッド: 入力に基づいてLLMを起動し、その結果を生成するプロセスを開始します。このプロセスは別のスレッドで実行され、メインスレッドはキューからトークンを取得し続けます。 -
cleanup
メソッド: ストリーミングが終了した後、スレッドがまだ動作していれば終了を待ちます。
class StreamingChain:
def __init__(self, llm, prompt):
self.llm_chain = LLMChain(llm=llm, prompt=prompt)
self.thread = None
def stream(self, input):
queue = Queue()
handler = StreamingHandler(queue)
def task():
self.llm_chain(input, callbacks=[handler])
self.thread = Thread(target=task)
self.thread.start()
try:
while True:
token = queue.get()
if token is None:
break
yield token
finally:
self.cleanup()
def cleanup(self):
if self.thread and self.thread.is_alive():
self.thread.join()
使用例
StreamingChain
を使用して、ユーザーの入力に基づいてLLMからの応答をリアルタイムで取得し表示する例です。
ここでは、ユーザーが「ポケモンについて100文字で説明して」と入力した場合の応答をストリーミングで表示します。
chain = StreamingChain(llm=chat, prompt=prompt)
for output in chain.stream(input={"content": "ポケモンについて100文字で説明して"}):
print(output)
参考資料