こちらの続きです。
前回はDatabricksノートブックにストリーミングを表示しただけでしたが、今回はstreamlitにストリーミングします。
そして、毎度のことですが @isanakamishiro2 さんありがとうございます!
前回試した際にどうしてもstreamlitでストリーミングされず、どうしたものかなと思っていたのですが助かりました!
APIサーバのURLは、同一クラスタを前提とした127.0.0.1を直接指定しています。
これをdriver-proxy-apiのURL指定にすることもできるのですが、この場合、レスポンスが遅い(すべて終わってからストリームでデータが配信される)。
サーバの実装
%pip install langchain==0.0.166 tiktoken==0.4.0 openai==0.27.6 faiss-cpu==1.7.4
Open AIのAPIキーを設定します。
os.environ['OPENAI_API_KEY'] = dbutils.secrets.get("demo-token-takaaki.yayoi", "openai_api_key")
今回、LangChainも使っているので、色々インポートします。
import re
import time
import pandas as pd
import mlflow
from langchain.chat_models import ChatOpenAI
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
# for streaming
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import (
AIMessage,
HumanMessage,
SystemMessage
)
from langchain.vectorstores.faiss import FAISS
from langchain.schema import BaseRetriever
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts import PromptTemplate
from langchain.base_language import BaseLanguageModel
from langchain import LLMChain
こちらを参考にさせていただいています。ChainStreamHandler
でストリーミングのコールバックを設定しているくらいしか中身はまだよくわかってません。
import threading
import queue
class ThreadedGenerator:
def __init__(self):
self.queue = queue.Queue()
def __iter__(self):
return self
def __next__(self):
item = self.queue.get()
if item is StopIteration: raise item
return item
def send(self, data):
self.queue.put(data)
def close(self):
self.queue.put(StopIteration)
class ChainStreamHandler(StreamingStdOutCallbackHandler):
def __init__(self, gen):
super().__init__()
self.gen = gen
def on_llm_new_token(self, token: str, **kwargs):
self.gen.send(token)
def llm_thread(g, prompt):
try:
chat = ChatOpenAI(
verbose=True,
streaming=True,
callbacks=[ChainStreamHandler(g)],
temperature=0.7,
)
chat([HumanMessage(content=prompt)])
finally:
g.close()
def chain(prompt):
g = ThreadedGenerator()
threading.Thread(target=llm_thread, args=(g, prompt)).start()
return g
from flask import Flask, jsonify, request, Response, stream_with_context
app = Flask("rag-chatbot-streaming")
app.config["JSON_AS_ASCII"] = False
@app.route('/', methods=['POST'])
def serve():
return Response(chain(request.json['prompt']), mimetype='text/event-stream')
from dbruntime.databricks_repl_context import get_context
ctx = get_context()
port = "7777"
driver_proxy_api = f"https://{ctx.browserHostName}/driver-proxy-api/o/0/{ctx.clusterId}/{port}"
print(f"""
driver_proxy_api = '{driver_proxy_api}'
cluster_id = '{ctx.clusterId}'
port = {port}
""")
こちらでサーバを起動。
app.run(host="0.0.0.0", port=port, debug=True, threaded=True, use_reloader=False)
streamlitで画面を構築
ノートブックではなくstreaming_sample.py
というpyファイルを作成します。
import streamlit as st
import numpy as np
import base64
import io
import os
import requests
import numpy as np
import pandas as pd
import json
st.header('サンプル Q&A bot')
def get_answer(question, placeholder):
token = ""
url = "http://127.0.0.1:7777"
headers = {'Authorization': f'Bearer {token}',
"Content-Type": "application/json",}
data = {"prompt": question}
headers = {"Accept": "text/event-stream"}
output_text = ""
s = requests.Session()
with s.post(url, headers=headers, json=data, stream=True) as resp:
for line in resp.iter_content(decode_unicode=True):
if line:
output_text += line
placeholder.markdown(output_text + "▌")
return output_text
return output_text
if "messages" not in st.session_state:
st.session_state.messages = []
# アプリの再実行の際に履歴のチャットメッセージを表示
for message in st.session_state.messages:
with st.chat_message(message["role"]):
st.markdown(message["content"])
# ユーザー入力に対する反応
if prompt := st.chat_input("質問はなんでしょうか?"):
# チャットメッセージコンテナにユーザーメッセージを表示
st.chat_message("user").markdown(prompt)
# チャット履歴にユーザーメッセージを追加
st.session_state.messages.append({"role": "user", "content": prompt})
# チャットメッセージコンテナにアシスタントのレスポンスを表示
with st.chat_message("assistant"):
message_placeholder = st.empty()
output_text = get_answer(prompt, message_placeholder)
message_placeholder.markdown(output_text)
# チャット履歴にアシスタントのレスポンスを追加
st.session_state.messages.append({"role": "assistant", "content": output_text})
streamlitの起動
別のノートブックを作成します。
%pip install streamlit watchdog
dbutils.library.restartPython()
from dbruntime.databricks_repl_context import get_context
def front_url(port):
"""
フロントエンドを実行するための URL を返す
Returns
-------
proxy_url : str
フロントエンドのURL
"""
ctx = get_context()
proxy_url = f"https://{ctx.browserHostName}/driver-proxy/o/{ctx.workspaceId}/{ctx.clusterId}/{port}/"
return proxy_url
PORT = 1501
# Driver ProxyのURLを表示
print(front_url(PORT))
# 利便性のためにリンクをHTML表示
displayHTML(f"<a href='{front_url(PORT)}' target='_blank' rel='noopener noreferrer'>別ウインドウで開く</a>")
streamlitを起動します。
streamlit_file = "/Workspace/Users/takaaki.yayoi@databricks.com/20230915_streaming_llmbot_w_driver_proxy/streaming_sample.py"
!streamlit run {streamlit_file} --server.port {PORT}
動きました!
次のステップ
LangChainを用いたカスタムモデルでも動作するようにします。