LoginSignup
4

DatabricksのドライバープロキシーでOpenAI API + streamlitのストリーミングを行う

Posted at

こちらの続きです。

前回はDatabricksノートブックにストリーミングを表示しただけでしたが、今回はstreamlitにストリーミングします。

そして、毎度のことですが @isanakamishiro2 さんありがとうございます!

前回試した際にどうしてもstreamlitでストリーミングされず、どうしたものかなと思っていたのですが助かりました!

APIサーバのURLは、同一クラスタを前提とした127.0.0.1を直接指定しています。

これをdriver-proxy-apiのURL指定にすることもできるのですが、この場合、レスポンスが遅い(すべて終わってからストリームでデータが配信される)。

サーバの実装

%pip install langchain==0.0.166 tiktoken==0.4.0 openai==0.27.6 faiss-cpu==1.7.4

Open AIのAPIキーを設定します。

os.environ['OPENAI_API_KEY'] = dbutils.secrets.get("demo-token-takaaki.yayoi", "openai_api_key")

今回、LangChainも使っているので、色々インポートします。

import re
import time
import pandas as pd
import mlflow

from langchain.chat_models import ChatOpenAI
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.embeddings.huggingface import HuggingFaceEmbeddings

# for streaming
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.callbacks.base import BaseCallbackHandler

from langchain.schema import (
    AIMessage,
    HumanMessage,
    SystemMessage
)

from langchain.vectorstores.faiss import FAISS
from langchain.schema import BaseRetriever
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts import PromptTemplate
from langchain.base_language import BaseLanguageModel
from langchain import LLMChain

こちらを参考にさせていただいています。ChainStreamHandlerでストリーミングのコールバックを設定しているくらいしか中身はまだよくわかってません。

import threading
import queue

class ThreadedGenerator:
    def __init__(self):
        self.queue = queue.Queue()

    def __iter__(self):
        return self

    def __next__(self):
        item = self.queue.get()
        if item is StopIteration: raise item
        return item

    def send(self, data):
        self.queue.put(data)

    def close(self):
        self.queue.put(StopIteration)

class ChainStreamHandler(StreamingStdOutCallbackHandler):
    def __init__(self, gen):
        super().__init__()
        self.gen = gen

    def on_llm_new_token(self, token: str, **kwargs):
        self.gen.send(token)

def llm_thread(g, prompt):
    try:
        chat = ChatOpenAI(
            verbose=True,
            streaming=True,
            callbacks=[ChainStreamHandler(g)],
            temperature=0.7,
        )
        chat([HumanMessage(content=prompt)])
    finally:
        g.close()


def chain(prompt):
    g = ThreadedGenerator()
    threading.Thread(target=llm_thread, args=(g, prompt)).start()
    return g
from flask import Flask, jsonify, request, Response, stream_with_context

app = Flask("rag-chatbot-streaming")
app.config["JSON_AS_ASCII"] = False

@app.route('/', methods=['POST'])
def serve():
    return Response(chain(request.json['prompt']), mimetype='text/event-stream')
from dbruntime.databricks_repl_context import get_context
ctx = get_context()

port = "7777"
driver_proxy_api = f"https://{ctx.browserHostName}/driver-proxy-api/o/0/{ctx.clusterId}/{port}"

print(f"""
driver_proxy_api = '{driver_proxy_api}'
cluster_id = '{ctx.clusterId}'
port = {port}
""")

こちらでサーバを起動。

app.run(host="0.0.0.0", port=port, debug=True, threaded=True, use_reloader=False)

streamlitで画面を構築

ノートブックではなくstreaming_sample.pyというpyファイルを作成します。

streaming_sample.py
import streamlit as st 
import numpy as np 
import base64
import io

import os
import requests
import numpy as np
import pandas as pd

import json

st.header('サンプル Q&A bot')

def get_answer(question, placeholder):
  token = ""
  url = "http://127.0.0.1:7777"

  headers = {'Authorization': f'Bearer {token}',
             "Content-Type": "application/json",}

  data = {"prompt": question}

  headers = {"Accept": "text/event-stream"}
  output_text = ""

  s = requests.Session()
  with s.post(url, headers=headers, json=data, stream=True) as resp:
        for line in resp.iter_content(decode_unicode=True):
            if line:
                output_text += line
                placeholder.markdown(output_text + "")
        return output_text
  
  return output_text


if "messages" not in st.session_state:
    st.session_state.messages = []

# アプリの再実行の際に履歴のチャットメッセージを表示
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# ユーザー入力に対する反応
if prompt := st.chat_input("質問はなんでしょうか?"):
    # チャットメッセージコンテナにユーザーメッセージを表示
    st.chat_message("user").markdown(prompt)
    # チャット履歴にユーザーメッセージを追加
    st.session_state.messages.append({"role": "user", "content": prompt})

    # チャットメッセージコンテナにアシスタントのレスポンスを表示
    with st.chat_message("assistant"):
      message_placeholder = st.empty()

      output_text = get_answer(prompt, message_placeholder)

      message_placeholder.markdown(output_text)

    # チャット履歴にアシスタントのレスポンスを追加
    st.session_state.messages.append({"role": "assistant", "content": output_text})

streamlitの起動

別のノートブックを作成します。

%pip install streamlit watchdog
dbutils.library.restartPython()
from dbruntime.databricks_repl_context import get_context

def front_url(port):
    """
    フロントエンドを実行するための URL を返す

    Returns
    -------
    proxy_url : str
        フロントエンドのURL
    """
    ctx = get_context()
    proxy_url = f"https://{ctx.browserHostName}/driver-proxy/o/{ctx.workspaceId}/{ctx.clusterId}/{port}/"

    return proxy_url

PORT = 1501

# Driver ProxyのURLを表示
print(front_url(PORT))

# 利便性のためにリンクをHTML表示
displayHTML(f"<a href='{front_url(PORT)}' target='_blank' rel='noopener noreferrer'>別ウインドウで開く</a>")

streamlitを起動します。

streamlit_file = "/Workspace/Users/takaaki.yayoi@databricks.com/20230915_streaming_llmbot_w_driver_proxy/streaming_sample.py"

!streamlit run {streamlit_file} --server.port {PORT}

動きました!

次のステップ

LangChainを用いたカスタムモデルでも動作するようにします。

Databricksクイックスタートガイド

Databricksクイックスタートガイド

Databricks無料トライアル

Databricks無料トライアル

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
What you can do with signing up
4