Calling a LangChain Model with Streaming from Databricks Apps (Streamlit)

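The previous article, "How to get RAG to stream on Databricks", showed how to wrap a LangChain-based RAG chain as a ChatModel, deploy it to Model Serving, and use it with streaming from the Playground.

This article is a follow-up: it covers how to call that endpoint with streaming from a Databricks Apps (Streamlit) front end, and what to watch out for.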

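Endpoint requirements

To call the endpoint with streaming from Streamlit, the endpoint must return a streaming response in ChatCompletion format.

As described in the previous article, wrap the LangChain chain in a ChatModel and implement the following:

  • A predict_stream method
  • Output conversion with ChatCompletionOutputParser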
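
For reference, a ChatCompletion-style streaming chunk carries its incremental text in choices[0].delta.content. The sketch below uses hand-written sample chunks (the id and text values are made up for illustration, and join_deltas is a hypothetical helper name) to show how a client would reassemble the full answer:

```python
# Hand-written sample chunks in the ChatCompletion streaming shape.
# The "id" and the text values are illustrative only, not real endpoint output.
sample_chunks = [
    {"id": "demo", "object": "chat.completion.chunk",
     "choices": [{"index": 0, "delta": {"role": "assistant", "content": "MLflow "}, "finish_reason": None}]},
    {"id": "demo", "object": "chat.completion.chunk",
     "choices": [{"index": 0, "delta": {"content": "manages "}, "finish_reason": None}]},
    {"id": "demo", "object": "chat.completion.chunk",
     "choices": [{"index": 0, "delta": {"content": "the ML lifecycle."}, "finish_reason": None}]},
    # The final chunk usually has an empty delta and a finish_reason.
    {"id": "demo", "object": "chat.completion.chunk",
     "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}]},
]


def join_deltas(chunks):
    """Concatenate delta.content from each chunk, skipping empty deltas."""
    parts = []
    for chunk in chunks:
        choices = chunk.get("choices") or [{}]
        content = choices[0].get("delta", {}).get("content")
        if content:
            parts.append(content)
    return "".join(parts)


full_text = join_deltas(sample_chunks)
print(full_text)
```

This is the same structure that the Streamlit client later in this article reads when it extracts delta.content.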

Complete code for creating the endpoint

Run the following notebook on Databricks to create the endpoint.

%pip install mlflow databricks-langchain langchain langchain-core langchain-community langchain-text-splitters faiss-cpu pyyaml --quiet
dbutils.library.restartPython()

import os

WORKSPACE_PATH = "/Workspace/Users/your-user/streaming_rag"
CATALOG = "your_catalog"
SCHEMA = "your_schema"
MODEL_NAME = f"{CATALOG}.{SCHEMA}.rag_agent"
ENDPOINT_NAME = "rag-agent-streaming"
LLM_ENDPOINT = "databricks-meta-llama-3-3-70b-instruct"
EMBEDDING_ENDPOINT = "databricks-bge-large-en"

os.makedirs(WORKSPACE_PATH, exist_ok=True)

# Write the agent definition file
AGENT_CODE = '''
import mlflow
from mlflow.pyfunc import ChatModel
from mlflow.types.llm import ChatMessage, ChatParams, ChatCompletionResponse, ChatCompletionChunk
from mlflow.langchain.output_parsers import ChatCompletionOutputParser
from databricks_langchain import ChatDatabricks, DatabricksEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from typing import List, Generator

mlflow.langchain.autolog()

config = mlflow.models.ModelConfig(development_config="agent_config.yaml")

LLM_ENDPOINT = config.get("llm_endpoint")
EMBEDDING_ENDPOINT = config.get("embedding_endpoint")
MAX_TOKENS = config.get("max_tokens")
RETRIEVER_K = config.get("retriever_k")

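# Small in-memory corpus used as the retrieval source for this demo
DOCUMENTS = [
    "Databricks is a unified platform for data engineering, data science, and machine learning.",
    "Unity Catalog is the data governance solution for Databricks.",
    "MLflow is a tool for managing the machine learning lifecycle.",
    "Delta Lake is an open-source storage layer.",
]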

embeddings = DatabricksEmbeddings(endpoint=EMBEDDING_ENDPOINT)
text_splitter = RecursiveCharacterTextSplitter(chunk_size=200, chunk_overlap=50)
splits = text_splitter.create_documents(DOCUMENTS)
vectorstore = FAISS.from_documents(splits, embeddings)
retriever = vectorstore.as_retriever(search_kwargs={"k": RETRIEVER_K})

llm = ChatDatabricks(endpoint=LLM_ENDPOINT, max_tokens=MAX_TOKENS)

prompt = ChatPromptTemplate.from_template("""
Context:
{context}

Question: {question}

Answer:
""")

def format_docs(docs):
    return "\\n\\n".join(doc.page_content for doc in docs)

chain = (
    {"context": retriever | format_docs, "question": RunnablePassthrough()}
    | prompt
    | llm
    | ChatCompletionOutputParser()
)

class RAGAgent(ChatModel):
    def __init__(self):
        self.chain = chain
    
    def _get_question(self, messages: List[ChatMessage]) -> str:
        for msg in reversed(messages):
            if msg.role == "user":
                return msg.content
        return ""
    
    def predict(self, context, messages: List[ChatMessage], params: ChatParams) -> ChatCompletionResponse:
        question = self._get_question(messages)
        response = self.chain.invoke(question)
        return ChatCompletionResponse.from_dict(response)
    
    def predict_stream(self, context, messages: List[ChatMessage], params: ChatParams) -> Generator[ChatCompletionChunk, None, None]:
        question = self._get_question(messages)
        for chunk in self.chain.stream(question):
            yield ChatCompletionChunk.from_dict(chunk)

mlflow.models.set_model(RAGAgent())
'''

agent_path = f"{WORKSPACE_PATH}/rag_agent.py"
with open(agent_path, "w") as f:
    f.write(AGENT_CODE)

CONFIG_YAML = f'''
llm_endpoint: "{LLM_ENDPOINT}"
embedding_endpoint: "{EMBEDDING_ENDPOINT}"
max_tokens: 500
retriever_k: 3
'''

config_path = f"{WORKSPACE_PATH}/agent_config.yaml"
with open(config_path, "w") as f:
    f.write(CONFIG_YAML)

# Log the model
import mlflow
from mlflow.models.resources import DatabricksServingEndpoint

mlflow.set_registry_uri("databricks-uc")

input_example = {"messages": [{"role": "user", "content": "Tell me about MLflow"}]}

with mlflow.start_run(run_name="rag_agent") as run:
    model_info = mlflow.pyfunc.log_model(
        python_model=agent_path,
        artifact_path="agent",
        model_config=config_path,
        input_example=input_example,
        pip_requirements=[
            "mlflow", "databricks-langchain", "langchain", "langchain-core",
            "langchain-community", "langchain-text-splitters", "faiss-cpu", "pyyaml",
        ],
        resources=[
            DatabricksServingEndpoint(endpoint_name=LLM_ENDPOINT),
            DatabricksServingEndpoint(endpoint_name=EMBEDDING_ENDPOINT),
        ],
        registered_model_name=MODEL_NAME,
    )
    print(f"Model URI: {model_info.model_uri}")

# Create or update the serving endpoint
from databricks.sdk import WorkspaceClient
from databricks.sdk.errors import NotFound
from databricks.sdk.service.serving import EndpointCoreConfigInput, ServedEntityInput
from mlflow import MlflowClient

w = WorkspaceClient()
mlflow_client = MlflowClient()

# Versions are returned as strings, so compare them as integers
model_versions = mlflow_client.search_model_versions(f"name='{MODEL_NAME}'")
latest_version = max(model_versions, key=lambda x: int(x.version)).version

config = EndpointCoreConfigInput(
    name=ENDPOINT_NAME,
    served_entities=[
        ServedEntityInput(
            entity_name=MODEL_NAME,
            entity_version=str(latest_version),
            workload_size="Small",
            scale_to_zero_enabled=True,
        )
    ],
)

try:
    # If the endpoint already exists, update it in place
    w.serving_endpoints.get(ENDPOINT_NAME)
    w.serving_endpoints.update_config(ENDPOINT_NAME, served_entities=config.served_entities)
except NotFound:
    # Otherwise create it
    w.serving_endpoints.create(name=ENDPOINT_NAME, config=config)


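Streamlit implementation

Request

Set stream=True in both places: in the payload, so that the endpoint returns a streaming response, and in requests.post(), so that requests does not buffer the whole body before returning: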

payload = {
    "messages": [{"role": "user", "content": prompt}],
    "stream": True
}

with requests.post(url, headers=headers, json=payload, stream=True) as response:
    ...
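
The snippet above assumes that url and headers already exist. As a minimal sketch, the URL and payload can be assembled by a small helper like the one below. The function name build_invocation_request and the host value are hypothetical, chosen only for illustration; the URL path matches the one used in the full app.py later in this article.

```python
def build_invocation_request(host: str, endpoint_name: str, prompt: str):
    """Return (url, payload) for a streaming call to a serving endpoint."""
    # The host may be given without a scheme; normalise it to https.
    if not host.startswith(("http://", "https://")):
        host = f"https://{host}"
    url = f"{host.rstrip('/')}/serving-endpoints/{endpoint_name}/invocations"
    payload = {
        "messages": [{"role": "user", "content": prompt}],
        "stream": True,  # ask the endpoint to stream its response
    }
    return url, payload


# Placeholder host and endpoint name, for illustration only
url, payload = build_invocation_request(
    "example.cloud.databricks.com", "rag-agent-streaming", "What is MLflow?"
)
print(url)
```

The returned values would then be passed as requests.post(url, headers=headers, json=payload, stream=True), with headers coming from the workspace authentication shown in app.py.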

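Parsing the SSE response

Model Serving returns streaming responses in Server-Sent Events (SSE) format. Each chunk arrives as a line of the form data: {JSON}; the loop below strips the prefix, stops at data: [DONE], and extracts delta.content: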

for line in response.iter_lines():
    if line:
        line_str = line.decode("utf-8")
        if line_str.startswith("data: "):
            data_str = line_str[6:]  # strip the "data: " prefix
            if data_str == "[DONE]":
                break
            data = json.loads(data_str)
            content = data.get("choices", [{}])[0].get("delta", {}).get("content", "")
            if content:
                yield content
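
The parsing loop above can be tried without a live endpoint. The following is a minimal, self-contained sketch that wraps the same logic in a generator and feeds it hand-written byte lines. The function name iter_sse_content and the sample lines are assumptions for illustration, not actual endpoint output.

```python
import json
from typing import Iterable, Iterator


def iter_sse_content(lines: Iterable[bytes]) -> Iterator[str]:
    """Yield delta.content strings from SSE lines of the form b'data: {...}'."""
    for raw in lines:
        if not raw:
            continue  # blank lines separate SSE events
        text = raw.decode("utf-8")
        if not text.startswith("data:"):
            continue  # ignore non-data fields such as comments
        data_str = text[len("data:"):].strip()
        if data_str == "[DONE]":
            break  # end-of-stream marker
        try:
            data = json.loads(data_str)
        except json.JSONDecodeError:
            continue  # skip lines that are not valid JSON
        choices = data.get("choices") or [{}]
        content = choices[0].get("delta", {}).get("content")
        if content:
            yield content


# Hand-written sample lines standing in for response.iter_lines()
fake_lines = [
    b'data: {"choices": [{"delta": {"content": "Hello"}}]}',
    b"",
    b'data: {"choices": [{"delta": {"content": ", world"}}]}',
    b"data: [DONE]",
    b'data: {"choices": [{"delta": {"content": "ignored"}}]}',
]
tokens = list(iter_sse_content(fake_lines))
print(tokens)
```

In the app, response.iter_lines() would take the place of fake_lines.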

Displaying with st.write_stream()

Passing a generator to Streamlit's st.write_stream() displays the tokens one after another as they arrive:

with st.chat_message("assistant"):
    response = st.write_stream(stream_response(prompt))
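
To check the chat UI before the endpoint is available, a stand-in generator can be passed to st.write_stream() instead of stream_response(). The sketch below is one way to do that; fake_token_stream is a hypothetical helper, and the delay value is an arbitrary choice to mimic token latency.

```python
import time
from typing import Iterator


def fake_token_stream(text: str, delay: float = 0.05) -> Iterator[str]:
    """Yield the text word by word, pausing briefly to mimic token latency."""
    words = text.split(" ")
    for i, word in enumerate(words):
        # Re-attach the space that split() removed, except after the last word.
        yield word if i == len(words) - 1 else word + " "
        time.sleep(delay)


# Joining the pieces reproduces the original text.
preview = "".join(fake_token_stream("Streaming works as expected", delay=0))
print(preview)
```

In the app this would be used as st.write_stream(fake_token_stream("...")). Because st.write_stream() returns the concatenated text once the generator is exhausted, the return value can be stored in the chat history, as app.py does.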

When creating the app, add the serving endpoint as an app resource.


Complete sample code

app.py

import streamlit as st
import requests
import json
import os
from databricks.sdk import WorkspaceClient

ENDPOINT_NAME = os.environ.get("ENDPOINT_NAME", "rag-agent-streaming")

def stream_response(prompt: str):
    w = WorkspaceClient()
    
    host = w.config.host
    if not host.startswith("https://"):
        host = f"https://{host}"
    
    url = f"{host}/serving-endpoints/{ENDPOINT_NAME}/invocations"
    
    headers = w.config.authenticate()
    headers["Content-Type"] = "application/json"
    
    payload = {"messages": [{"role": "user", "content": prompt}], "stream": True}
    
    with requests.post(url, headers=headers, json=payload, stream=True) as response:
        if response.status_code != 200:
            yield f"Error: {response.status_code} - {response.text}"
            return
        
        for line in response.iter_lines():
            if line:
                line_str = line.decode("utf-8")
                if line_str.startswith("data: "):
                    data_str = line_str[6:]
                    if data_str == "[DONE]":
                        break
                    try:
                        data = json.loads(data_str)
                        content = data.get("choices", [{}])[0].get("delta", {}).get("content", "")
                        if content:
                            yield content
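                    except (json.JSONDecodeError, IndexError, AttributeError):
                        # Skip malformed or unexpected chunks instead of aborting the stream
                        continue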

st.title("RAG Streaming Demo")

if "messages" not in st.session_state:
    st.session_state.messages = []

for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.write(message["content"])

if prompt := st.chat_input("Enter your question"):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.write(prompt)
    
    with st.chat_message("assistant"):
        response = st.write_stream(stream_response(prompt))
        st.session_state.messages.append({"role": "assistant", "content": response})

app.yaml

command:
  - streamlit
  - run
  - app.py

env:
  - name: ENDPOINT_NAME
    value: "your-endpoint-name"
  - name: DATABRICKS_HOST
    value: "https://your-workspace.cloud.databricks.com"

requirements.txt

streamlit
requests
databricks-sdk

The response now streams in the app as well.


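Summary

Client                      | Streaming support | Required implementation
AI Playground               | Automatic         | None (works as long as the endpoint implements predict_stream)
Databricks Apps (Streamlit) | Manual            | Specify stream=True, parse SSE, use st.write_stream()

Item     | Requirement
Endpoint | ChatModel with predict_stream + ChatCompletionOutputParser
Request  | Specify stream=True in both the payload and requests.post()
Response | Parse the SSE format and extract delta.content
Display  | Pass the generator to st.write_stream()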
