1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

DatabricksのドライバープロキシーでOpenAI APIのストリーミングを行う

Last updated at Posted at 2023-08-25

なかなかニッチだと思いつつ、少し進捗が出たのでまとめます。

要件

APIでOpenAIのコンプリーションを取得する際、デフォルトではすべての生成が終わってから返却されます。プロンプトによっては数十秒待つことに。これを改善するために、コンプリーションをストリームすることができます。

OpenAI APIを直接呼び出すのであればシンプルです。こちらの記事に書いているように引数streamTrueに設定するだけです。

ただ、実際にはAPIの呼び出し前や呼び出し後に処理を追加したいケースが多く、その場合には前処理、後処理を含むロジックを一連のパイプラインとしてデプロイする必要が出てきます。

そのため、ここではDatabricksのドライバープロキシーを用いてFlaskに処理のパイプラインをデプロイして、クライアントから呼び出す構成を組みます。その際に生成結果をストリーミングするようにします。

実装

サーバー

import os
os.environ["OPENAI_API_KEY"] = dbutils.secrets.get("demo-token-takaaki.yayoi", "openai_api_key")
import openai

def chat_gpt_helper(prompt):
    """
    この関数はOpenAIのcompletions APIを用いてGpt4モデルのレスポンスを返却します
    """
    try:
        resp = ''
        openai.api_key = os.getenv('OPENAI_API_KEY')
        for chunk in openai.ChatCompletion.create(
            model="gpt-4",
            messages=[{
                "role": "user",
                "content":prompt
            }],
            stream=True,
        ):
            content = chunk["choices"][0].get("delta", {}).get("content")
            if content is not None:
                    yield content

    except Exception as e:
        print(e)
        return str(e)

Flask側でもMIMEタイプtext/event-streamでストリームを搬出しています。

from flask import Flask, jsonify, request, Response, stream_with_context

app = Flask("openai-streaming")

@app.route('/', methods=['POST'])
def stream_chat_gpt():
        """
        ChatGPTのレスポンスをストリーミングします
        """
        prompt = request.get_json(force = True).get('prompt','')
        return Response(stream_with_context(chat_gpt_helper(prompt)),
                         mimetype='text/event-stream')

こちらはサーバーのURLやポート番号を確認するためのものです。

from dbruntime.databricks_repl_context import get_context
ctx = get_context()

port = "7777"
driver_proxy_api = f"https://{ctx.browserHostName}/driver-proxy-api/o/0/{ctx.clusterId}/{port}"

print(f"""
driver_proxy_api = '{driver_proxy_api}'
cluster_id = '{ctx.clusterId}'
port = {port}
""")
driver_proxy_api = 'https://xxxxxx.databricks.com/driver-proxy-api/o/0/0713-082554-nqq9eafd/7777'
cluster_id = '0713-082554-nqq9eafd'
port = 7777

以下のコマンドでドライバープロキシーを起動します。明示的に停止しない限り実行され続けるので注意してください。

app.run(host="0.0.0.0", port=port, debug=True, use_reloader=False, threaded=True)

クライアント

今回は同じDatabricksワークスペースの同じクラスターに別のノートブックをアタッチします。

クライアント側もtext/event-streamを受け付けるようにしています。

client
import requests
import json

def get_stream(data):

    headers = {"Accept": "text/event-stream"}
    output_text = []

    s = requests.Session()
    with s.post("http://127.0.0.1:7777/", headers=headers, json=data, stream=True) as resp:
        for line in resp.iter_content(decode_unicode=True):
            if line:
                print(line, end="", flush=True)
                output_text.append(line)
        return "".join(output_text)

プロンプトを入力してみます。

data = {"prompt": "猫を讃える歌を歌ってください"}
get_stream(data)

動きました!

参考資料

色々参考にさせていただいています。

Databricksクイックスタートガイド

Databricksクイックスタートガイド

Databricks無料トライアル

Databricks無料トライアル

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?