0
0

Databricks Model ServingでRerank用APIを作る

Posted at

こちらの続き?です。

導入

以前、こちらの記事でRAG処理におけるRerankについてやりました。

独立したAPIサーバがあると利便性が高いため、こちらをDatabricks Model Servingを使ってサービングしてみます。

また、日本語対応のColBERTv2モデルが出ていますので、モデルをこちらに置き換えます。
日本語性能のベンチマークでは、かなりよい結果のモデルです。

以下コードの実行環境であるDatabricksのDBRは15.2ML、インスタンスタイプはg4dn.xlarge(AWS)です。

Step1. パッケージインストール

ragatouilleによってこのモデルを使ったRerankを実行しますので、必要なパッケージをインストールします。

%pip install "ragatouille==0.0.8post1" fugashi unidic_lite
%pip install "mlflow-skinny[databricks]>=2.13.0"

dbutils.library.restartPython()

Step2. Pyfuncカスタムモデルクラスの作成

Pyfuncのカスタムモデルクラスを定義します。
中身は非常にシンプルで、load_context内でragatouilleを使ってモデルをロードし、predict内でrerankメソッドを実行しているだけです。
※ 実用にはエラーハンドリングが必要ですが、割愛。

import mlflow
import pandas as pd

# Define a custom PythonModel
class RagatouilleRerankModel(mlflow.pyfunc.PythonModel):

    def load_context(self, context):
        from ragatouille import RAGPretrainedModel

        model_directory = context.artifacts["model"]

        reranker = RAGPretrainedModel.from_pretrained(model_directory)

        self._reranker = reranker

    def predict(self, context, model_input, params=None):

        query = model_input["query"][0]
        documents = model_input["document"].tolist()

        k = params.get("k", 10) if params else 10

        result = self._reranker.rerank(query=query, documents=documents, k=k)        

        return pd.DataFrame(result)

Step3. MLFlowへのモデルロギング

Mlflowへ作成したカスタムモデルをロギングします。

まずは、入出力のシグネチャや使用するパッケージの情報を作成。

import pandas as pd
import mlflow
from mlflow.models.signature import infer_signature

inputs = pd.DataFrame(
    [
        {"query": "Databricksとは何?", "document": "Document1"},
        {"query": "Databricksとは何?", "document": "Document2"},
    ]
)
outputs = pd.DataFrame(
    [
        {"content": "Sample output", "score": 22.0, "rank": 0, "result_index": 1},
        {"content": "Sample output", "score": 21.0, "rank": 1, "result_index": 0},
    ]
)
params = {"k": 10}

# Define signature
signature = infer_signature(inputs, outputs, params)

# Define input example
input_example = inputs

# Define Library
pip_requirements = mlflow.pyfunc.get_default_pip_requirements() + [
    "ragatouille==0.0.8post1",
    "fugashi",
    "unidic_lite",
]

そして、モデルをロギング。

import mlflow

mlflow.set_registry_uri("databricks-uc")

# Unity Catalog内の適当な場所へモデルを作成
registered_model_name = "training.llm.ragatouille_reranker"

model = RagatouilleRerankModel()
# 事前にモデルを保管しておいた場所を指定
model_path = "/Volumes/training/llm/model_snapshots/models--bclavie--JaColBERTv2/"

with mlflow.start_run() as run:
    _ = mlflow.pyfunc.log_model(
        artifact_path="model",
        python_model=model,
        artifacts={
            "model": model_path,
        },
        pip_requirements=pip_requirements,
        signature=signature,
        input_example=input_example,
        await_registration_for=1200,  # モデルサイズが大きいので長めの待ち時間にします
        registered_model_name=registered_model_name,
    )

Step4. モデルサービングエンドポイントを作成

モデルサービングエンドポイントを、REST APIを呼び出すことで作成。
(モデルサービングのUI上でも可能です)

import requests
import json
from pprint import pprint
from mlflow import MlflowClient
import mlflow

mlflow.set_registry_uri("databricks-uc")

# 最新バージョンを取得
client=MlflowClient()
versions = [mv.version for mv in client.search_model_versions(f"name='{registered_model_name}'")]

# APIエンドポイントのURLとトークンを取得
API_ROOT = (
    dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiUrl().get()
)
API_TOKEN = (
    dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get()
)

data = {
    "name": "ragatouille-reranker-endpoint",
    "config": {
        "served_entities": [
            {
                "entity_name": registered_model_name,
                "entity_version": versions[0],
                "workload_type": "GPU_SMALL",
                "workload_size": "Small",
                "scale_to_zero_enabled": True,
            }
        ]
    },
}

headers = {"Context-Type": "text/json", "Authorization": f"Bearer {API_TOKEN}"}

response = requests.post(
    url=f"{API_ROOT}/api/2.0/serving-endpoints", json=data, headers=headers
)

print(json.dumps(response.json(), indent=4))

問題がなければ、コンテナイメージの作成が開始され、コンテナのデプロイがされます。ゆっくり待ちましょう。

Step5. モデルを試してみる

Rerankするための仮データを準備します。
なんらかのRetrieverから、queryに関連する文書を10個取得してきたという想定です。

import pandas as pd

query = "この契約において知的財産権はどのような扱いなのか?"

raw_results = [
    "納入物の改良・改変をはじめとして、あらゆる使用(利用)態様を含む 。また、本契約\nにおいて「知的財産権」とは、知的財産基本法第2条第2項所定の知的財産権をいい、\n知的財産権を受ける権利及びノウハウその他の秘密情報を含む。  \n2 乙は、 納入物に第三者の知的財産権を利用する場合には、 第1条第2項 の規定に従い、\n乙の費用及び責任において当該第三者から本契約の履行及び本契約終了後の甲による",
    "(契約保証金)  \n第3条 甲は、本契約に係る乙が納付すべき契約保証金の納付を全額免除する。  \n \n (知的財産 権の帰属及び 使用) \n第4条 本契約の締結時に乙が既に所有又は管理していた 知的財産権(以下「 乙知的財産\n権」という。)を 乙が納入物に使用した場合には、甲は、当該乙知的財産権を、仕様書\n記載の「目的」のため、仕様書の「納入物」の項 に記載した利用方法に従い、本契約終",
    "が所有し、又は管理する知的財産権の実施許諾や動産・不動産の使用許可の取得等を含\nむ。)が必要な場合には乙の費用及び責任で行うものとする。 甲の指示により、委託者\n名を明示して業務を行う場合も同様とする。  \n3 甲は、委託業務及び納入物に関して、約定の委託金額以外の支払義務を負わない。 本\n契約終了後の納入物の利用についても同様とする。 委託金額には委託業務の遂行に必要",
    "新規知的財産権は 約定の委託金額以外の追加支払なしに、納入物の引渡しと同時に乙か\nら甲に譲渡され、甲単独に帰属する。  \n5 前項の規定にかかわらず、著作権等について は第28条の定めに従う。  \n6 乙は、本契約終了後であっても、知的財産権の取 扱いに関する本契約 の約定を 自ら遵\n守し、及び第7条第1項 の再委託先に遵守させ ることを 約束する。",
    "自体を含 む。)に関 して著作者人格権を行使しないことに同意する。また、乙は、当該\n著作物の著作者が乙以外の 者であるときは、当該著作者が著作者人格権を行使しないよ\nうに必要な措置をとるものとす る。 \n4 乙は、 本条及び知的財産権の帰属等に関する本契約及び仕様書の約定を 遵守するため、\n必要な範囲で職 務発明や著作権に 関する管理規程その他の社内規程を整備 するととも",
    "乙の費用及び責任において当該第三者から本契約の履行及び本契約終了後の甲による\n納入物の利用 に必要な書面の許諾を得なければならない。なお、第三者より当該許諾に\n条件を付された場合には(以下「第三者の許諾条件」という。)、乙は、納入物に第三\n者の知的財産権を利用する前に、甲に対して第三者の 許諾条件を書面で速やかに通知し",
    "甲に譲渡され、甲単独に帰属 する。乙は、 甲が求める場合には、本項に定める著作権の\n譲渡証の作 成等、譲渡を証する書面の作成に協力しなければなら ない。  \n2 本契約締結日現在乙、乙以外の委託事業参加者又は第三者の権利対象となる著作 物が\n納入物に含まれている場合であっても、甲は、納入物の利用のため、本契約期間中 及び",
    "しに、納入物の利用 に必要な範囲で、前項の第三者の知的財産権を 自由かつ対価の追加\n支払なしに 使用し、又は第三者に使用させることができる。  \n4 委託業務の遂行中に納入物に関して乙(甲の同意を得て一部を再委託する場合は再委\n託先を含む。) が新たに知的財産権(以下 「新規知的財産権」という。)を取得した場\n合には、乙は、その詳細を書面にしたものを納入物に添付して甲に提出するものとする。",
    "り本契約が終了した後で あっても、なおその効力を有する。  \n \n(著作権等の帰属)  \n第28条  納入物に係る 著作権(著 作権法第2 7条及び第28条の権利を含む。ただし、\n本契約締結日現 在、乙、乙以外 の委託事業参加 者又は第三者の権利 対象となっているも\nのを除く。以下同じ。)は、委 託金額以外の追加支払なしに、その発生と同時に乙から",
    "なければならない。甲は、当該第三者の許諾条件に同意できない場合には、本契約の 解\n約又は変更を含め、乙に対して協議を求めることができる。甲が当該条件に同意した場\n合、乙は、委託業務の遂行及び納入物の作成に 当たって第三者の許諾条件を遵守するこ\nとにつき全責任を負う。   \n3 甲は、第三者の許諾条件を遵守することを条件として、 本契約終了後も 期間の制限な",
]

pdf = pd.DataFrame(raw_results, columns=["document"]).assign(query=query)
data_body = pdf.to_dict(orient="split")

では、作成したエンドポイントにクエリを発行します。

import mlflow.deployments
import json

endpoint_name = "ragatouille-reranker-endpoint"
client = mlflow.deployments.get_deploy_client("databricks")

# Rerank実行
response = client.predict(endpoint=endpoint_name, inputs={"dataframe_split": data_body})

pd.DataFrame(response["predictions"])

image.png

結果として、入力パラメータとして入れたクエリと文書について、rerankされた一覧を得ることができました。

まとめ

RAGにおいて、特に複数種類のRetriverから関連文書を取得する場合、Rerankによって必要文書の再優先度付け・絞り込みを行うことは、検索精度を上げるために有効なことが多いです。
その場合、独立したAPIサーバとして存在しておくと、RAGのパイプラインを構築する上で非常に便利です。

Databricksでは、必要なモデルに合わせて、簡単にプロダクショングレードでのサービングができます。
やりすぎると逆に高コストになりそうですが、うまくデザインするとコストを最小化して必要なサービスを運用し続けられる環境を構築できそうです。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0