LoginSignup
3
0
記事投稿キャンペーン 「AI、機械学習」

LangServeを触ってみる on Databricks

Last updated at Posted at 2023-11-11

以前からやろうと思っていたのですが、ようやく触れました。

導入

LangServeはlangchainのChainを簡単にREST APIとしてデプロイすることができる仕組です。
FastAPI等との統合がされており、APIドキュメントの生成や、APIのPlayground機能なども提供しています。

Langchain TemplateもLangServeで試すような構造になっています。
また、Langchain側でホストするLangServe環境も今後公開される予定のようです。

実際どんな感じなんだろうと思っていたので、ローカルLLMを使ってDatabricks上で実行してみます。

Databricksにはモデルサービング機能がありますので、本番利用においてはこちらを利用する方がよいと思います。
今回は開発・検証用途という位置づけです。

Step1. サーバの構築

LangServを使ってREST APIサーバを構築します。

インストール

まずは必要なモジュールのインストール。
AWQフォーマットのモデルを使うので、autoawqもインストールします。

%pip install -U transformers accelerate langchain
%pip install fastapi uvicorn nest_asyncio "autoawq==0.1.5" "langserve[server]"

dbutils.library.restartPython()

Databricksのノートブック上でFastAPIを使えるようにするために、nest_asyncioパッケージでイベントループのネストを有効にします。(手前ミソですが、このあたりの補足はこちらの記事に記載)

import nest_asyncio
import uvicorn

nest_asyncio.apply()

モデルダウンロード

利用するモデルをダウンロードします。
今回は現在Huggingface上で人気のあるopenchat_3.5のAWQ量子化モデルを利用します。
Mistral 7Bベースの多言語モデルですが、日本語もそれなりに使えて使い勝手が良さそうな印象。

def download_model(model_id:str):
    import os
    from huggingface_hub import snapshot_download

    UC_VOLUME = "/Volumes/training/llm/model_snapshots"

    local_dir = f"/tmp/{model_id}"
    uc_dir = f"/models--{model_id.replace('/', '--')}"

    snapshot_location = snapshot_download(
        repo_id=model_id,
        local_dir=local_dir,
        local_dir_use_symlinks=False,
    )

    dbutils.fs.cp(f"file:{local_dir}", f"{UC_VOLUME}{uc_dir}", recurse=True)


model_id = "TheBloke/openchat_3.5-AWQ"
download_model(model_id)

モデルのロード

先ほどダウンロードしたモデルをロードし、langchainのChat Modelとしてラップします。
ChatHuggingFaceModelクラスは、こちらの記事で作成した自作クラスを再利用しました。
もちろんローカルLLMを使わずにlangchain.llms.openai.OpenAIなどと入れ替えてもらって大丈夫です。


from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers_chat import ChatHuggingFaceModel

def get_downloaded_model_path(model_id:str) -> str:
    UC_VOLUME = "/Volumes/training/llm/model_snapshots"
    uc_dir = f"/models--{model_id.replace('/', '--')}"

    return f"{UC_VOLUME}{uc_dir}"

model_path = get_downloaded_model_path(model_id)

generator = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_path)

chat_model = ChatHuggingFaceModel(
    generator=generator,
    tokenizer=tokenizer,
    human_message_template="GPT4 Correct User: {}<|end_of_turn|>",
    ai_message_template="GPT4 Correct Assistant: {}",
    repetition_penalty=1.2,
    temperature=0.1,
    max_new_tokens=1024,
)

プロンプトテンプレートの作成

単純に与えられた内容をそのまま実行するテンプレートにします。

from langchain.prompts import ChatPromptTemplate
from langchain.prompts.chat import (
    AIMessagePromptTemplate,
    HumanMessagePromptTemplate,
)

template = """{question}"""
prompt = ChatPromptTemplate.from_messages(
    [
        HumanMessagePromptTemplate.from_template(template),
        AIMessagePromptTemplate.from_template(""),
    ]
)

チェーンの作成&LangServeの起動

ようやく今日のメイン。

FastAPIのオブジェクトを作成。

from fastapi import FastAPI, Response, Request
from fastapi.responses import HTMLResponse
from langserve import add_routes

app = FastAPI(
    title="LangChain Server",
    version="1.0",
    description="A simple api server using Langchain's Runnable interfaces",
)

ルートの登録。
add_routesを呼び出し、実行するChain(第2引数)を登録します。
今回はプロンプトテンプレートとLLM(Chat Model)だけのシンプルなChainです。

add_routes(
    app,
    prompt | chat_model,
    path="/chain",
)

最後にAPIサーバを起動。

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)

Step2. クライアント準備/リクエスト実行

別のノートブックを作成し、LangServeサーバと通信するクライアントを用意します。

モジュールインストール

必要なモジュールをインストール。

%pip install -U langchain "langserve[client]"

dbutils.library.restartPython()

APIサーバのURLを取得する関数を作成。
今回はサーバ側と同一のクラスタ上で動かすことを想定して、単純にhttp://127.0.0.1:8000を返すようにしています。別の環境でクライアントを実行する場合は、Driver Proxy APIを利用してください。(コメントアウト部分を外すと該当のURLが取得できます)

from dbruntime.databricks_repl_context import get_context

def api_url(port: int = 8000):
    # API を実行するための URL を取得
    # ブラウザでノートブックを開いているとき、ノートブック上で実行される各プロセスで使用できる URL を生成
    proxy_url = f"http://127.0.0.1:{port}/"
    # ctx = get_context()
    # proxy_url = f"https://{ctx.browserHostName}/driver-proxy-api/o/{ctx.workspaceId}/{ctx.clusterId}/{port}/"

    return proxy_url

リクエストを実行

通常のrequests呼び出しでもよいのですが、今回はPython用のSDKを使ってリクエストを実行します。
詳細は下記リンク先を確認ください。

RemoteRunnableを使って、リモート接続用のChainを作成。

from langchain.schema import SystemMessage, HumanMessage
from langchain.prompts import ChatPromptTemplate
from langchain.schema.runnable import RunnableMap
from langserve import RemoteRunnable

# Access Tokenの設定。ローカルアドレスとの通信においては無くても動く。
api_token = dbutils.secrets.get("llm", "api_token")
headers = {"Authorization": f"Bearer {api_token}"}

# リモート接続するチェーンを作成
chain = RemoteRunnable(api_url()+"chain", headers=headers)

あとは通常のChainと同様に使えます。
単純なリクエストを実行してみます。

question = "Databricksとは何かを箇条書きで5個以内で回答してください。"
result = chain.invoke({"question": question})

print(result.content)
出力
1. Databricksは、アプリケーションを簡単に構築、デプロイ、スケールするためのデータ分析プラットフォームです。
2. データ分析のための統合環境を提供し、PythonやRを使用してデータを分析できます。
3. データをストレージし、処理するための高性能なクラスターを提供します。
4. 機械学習、データフロー、データベース、およびビジュアライザーを含む、統合されたデータ分析ツールを提供します。
5. クラウド上で実行されるため、スケーラビリティと柔軟性が高く、データ分析のスピードやスケーラビリティが向上します。

stream/astreamによるストリーミングでの出力も可能。

question = "Databricksとは何かを箇条書きで5個以内で回答してください。"

# Supports astream
async for msg in chain.astream({"question": question}):
    print(msg.content, end="", flush=True)

まとめ

langchainのChainを簡単にREST API化できるのはいいと思いました。
ただ、環境起因なのかレスポンスが返ってこないなど動作が少し不安定な感じ。
Databricks上でも安定的に使えると、開発用のRESTサーバとかに容易に使えそう。

3
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
0