2
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

Databricks Mosaic AI Model Servingで高性能な日本語埋め込みモデルのエンドポイント作成

Posted at

導入

ここ最近PKSHA Technologies様や名古屋大学様より高性能な日本語埋め込みモデルやそれらを利用したRerankerが次々と公開されています。

以下のように既にどちらも試されておられる方もいらっしゃいます。
詳細はこちらで確認いただくのがよいと思います。

軽量・高性能で日本語に強い埋め込みモデルは一般的なRAG利用において非常に重要です。
せっかくなので、これらをDatabricks Mosaic AI Model Servingを使ってAPIエンドポイントを作成するやり方をまとめてみます。
簡易的な実装なので、改善点は多数ありますが。

構築はDatabricks on AWS上で行いました。ノートブック上のクラスタはサーバレスクラスタで動作します。

上記、様々なモデルが公開されていますが、この記事ではPKSHA Technologies様のRoSEtta-base-jaのエンドポイントを公開してみます。
(設定を一部変更することでpkshatech/GLuCoSE-base-ja-v2やcl-nagoya/ruri-largeが同様に動作することも確認済み)

埋め込みモデル:RoSEtta-base-jaとは

公式記事より。

 LLMと検索拡張生成(Retrieval-Augmented Generation、以下RAG)の活用において、長文を含む多様なドキュメント処理のニーズが高まっています。しかし、現在の日本語文埋め込みモデルの多くは最大入力長系列について512トークンまでの制限があり、1024トークン以上を扱える実用的な軽量モデルが存在しませんでした。

 本研究では、長い系列を扱う場合に適切とされている相対位置埋め込み「RoPE」を取り入れたBERT、「RoFormer」に事前学習・事後学習を行い、最大1024トークンの系列を扱うことのできる日本語文埋め込みモデルを構築しました。GLuCoSE v2と同様の蒸留と検索に特化した追加学習によって、包括的な評価において先行研究に対して同等以上の性能になりました。本研究の成果であるモデルをRoSEtta(RoFormer-based Sentence Encoder through Distillation)という名称で、商用利用可能なライセンスで公開いたしました(https://huggingface.co/pkshatech/RoSEtta-base-ja)。

最大の特徴は上記記載のように入力長が1024トークンまで大きくなっていることですね。
多くの軽量埋め込みモデルは512トークンまでの入力にしか対応していないものが多いのですが、この埋め込みモデルではより多くの入力に対してベクトル化が可能です。

では、このモデルのAPIエンドポイントを作成していきます。

おおまかなステップとしては、以下のようになります。

  1. MLflow登録用のカスタムpyfuncモデルを定義
  2. MLflowにモデルをロギング
  3. Mosaic AI Model Servingでエンドポイントを作成

Step1. MLflow登録用のカスタムpyfuncモデルを定義

Databricks Mosaic AI Model Servingを利用するためにはMLflowにモデルを登録する必要があります。
MLflowは標準でSentence Transformersモデルをさくっとロギングするフレーバーが存在しているのですが、Model Serving利用時に入出力仕様の都合が悪かったので、pyfuncカスタムモデルをまずは定義します。

ノートブックを作成して、まずは関連パッケージをインストール。

%pip install -U "sentence-transformers==3.0.1" "sentencepiece==0.2.0" "fugashi==1.3.2" "unidic-lite==1.0.8"
%pip install -U "mlflow-skinny[databricks]>=2.16.0"

dbutils.library.restartPython()

fugashiunidic-liteはRoSEtta-base-janの動作には不要です。
cl-nagoya/ruri-largeを動かす際には必要になるので合わせてインストールするようにしています。

このあとの処理用に、パラメータを変数格納しておきます。
※ なお、事前に変数model_pathの場所へこちらのモデルファイルをダウンロードしてあります。

# MLflowのUnity Catalog登録時の設定
catalog='training'
schema='llm'
registered_model_name='training.llm.pkshatech-RoSEtta-base-ja'
model_path='/Volumes/training/llm/model_snapshots/models--pkshatech--RoSEtta-base-ja/'

# Model Serving Endpoint構築時の設定情報
endpoint_name='embedding-pkshatech-RoSEtta-base-ja-endpoint'
endpoint_workload_type='CPU'
endpoint_workload_size='Small'
endpoint_scale_to_zero_enabled='true'

# Sentence Transformersに設定するプロンプトテンプレート
# encode時にこのプレフィクスを追加して処理実行できる
prompt_template_query='query: '
prompt_template_document='passage: '

MLflowのアーティファクトとして保存する、Sentence Transformers用のJSON設定ファイルを作成します。
これをカスタムモデルの中で読み込んで参照します。

# MLflow用のconfigを作成
import json
from pprint import pprint

config = {
    "model_name": registered_model_name,
    "prompts": {
        "search_query": prompt_template_query,
        "search_document": prompt_template_document,
    },
    "default_prompt_name": "search_document",
}
json.dump(config, open("mlflow_model_config.json", "w"))

# 保存結果表示
pprint(json.load(open("mlflow_model_config.json")))
表示
{'default_prompt_name': 'search_document',
 'model_name': 'training.llm.pkshatech-RoSEtta-base-ja',
 'prompts': {'search_document': 'passage: ', 'search_query': 'query: '}}

次が最大のポイントである、MLflow用のpyfuncカスタムモデルを定義です。
MLflowのModels from Code機能を使って定義しています。
※ このセルを実行すると同じフォルダ内にsentence_transformers_embedding.pyが作られ、この後のMLflowのロギング時にこのファイルを指定します。

少しコード量が多いですが、主だった処理はSentence Transformersを使ってモデルを読み込み、predictメソッドでencode処理を実行しているところです。
また、Sentence Transformersのプロンプトテンプレート機能を使って、query: などのプレフィックスをencode時に付加しています。

入出力の仕様については、こちらに記載の仕様に準拠しました。

%%writefile "./sentence_transformers_embedding.py"

import json
import logging
from mlflow.models import set_model
from mlflow.pyfunc import PythonModel
from sentence_transformers import SentenceTransformer

logger = logging.getLogger(__name__)

class SentenceTransformersEmbeddingModel(PythonModel):
    def load_context(self, context):
        """Load the model from the given context."""
        
        model_path = context.artifacts["model_path"]
        config_file = context.artifacts["config_file"]

        logger.info(f"Loading model from {model_path}")
        logger.info(f"Loading config from {config_file}")

        config = json.load(open(config_file))
        prompts = config.get("prompts")
        default_prompt_name = config.get("default_prompt_name")
        model_name = config.get("model_name", model_path)

        self.model = SentenceTransformer(
            model_path,
            prompts=prompts,
            default_prompt_name=default_prompt_name,
            trust_remote_code=True,
        )
        self.model_name = model_name

        logger.info("Model and config loaded successfully")

    def predict(self, context, model_input):
        """ Predict embeddings for the given input. """
        
        if self.model is None:
            raise ValueError(
                "The model has not been loaded. "
                "Ensure that 'load_context' is properly executed."
            )

        # Pandas DataFrameの先頭データのみ利用
        dict_input = model_input.head(1).to_dict("records")[0]
        logger.info("Received input: {dict_input}")     
        texts = dict_input.get("input", [])
        prompt = dict_input.get("input_type", None)

        # encode
        logger.info(f"Received input for prediction: {texts}")
        logger.info(f"Prompt Template: {prompt}")        
        result = self.model.encode(texts, prompt=prompt)

        # usage
        input_ids = self.model.tokenize(texts)["input_ids"]
        prompt_tokens = sum([len(id) for id in input_ids])        
        logger.info(f"Prompt Tokens: {prompt_tokens}")

        return self._build_response(self.model_name, result, prompt_tokens)

    @staticmethod
    def _build_response(model_name, embeddings, prompt_tokens):

        data = [
            {"object": "embedding", "index": i, "embedding": emb.tolist()}
            for i, emb in enumerate(embeddings)
        ]

        return {
            "object": "list",
            "data": data,
            "model": model_name,
            "usage": {"prompt_tokens": prompt_tokens, "total_tokens": prompt_tokens},
        }

set_model(SentenceTransformersEmbeddingModel())

これで準備完了です。

Step2. MLflowにモデルをロギング

次にMLflowにStep1で定義したカスタムモデルクラスを使ってロギングします。

まずは入出力仕様をSignatureとして定義。

# 以下の仕様を参照
# https://docs.databricks.com/ja/machine-learning/model-serving/score-foundation-models.html#query-an-embedding-model

from mlflow.models import infer_signature
from mlflow.types.schema import Schema, ColSpec

input_example = {"input": ["embedding text"]}
output_example = {
    "object": "list",
    "data": [
        {
            "object": "embedding",
            "index": 0,
            "embedding": [
                2.0787181854248047,
                0.27202653884887695,
            ],
        }
    ],
    "model": "sample_embedding_model",
    "usage": {"prompt_tokens": 2, "total_tokens": 2},
}

signature = infer_signature(input_example, output_example)
# input_typeオプションパラメータを追加
signature.inputs.inputs.append(
    ColSpec(name="input_type", type="string", required=False)
)

signature

以下のような定義となります。

表示
inputs: 
  ['input': Array(string) (required), 'input_type': string (optional)]
outputs: 
  ['object': string (required), 'data': Array({embedding: Array(double) (required), index: long (required), object: string (required)}) (required), 'model': string (required), 'usage': {prompt_tokens: long (required), total_tokens: long (required)} (required)]
params: 
  None

次に依存関係の設定を作成。
必要最低限の依存関係のみ含まれるようにしています。

import mlflow

extra_pip_requirements = [
    "sentence-transformers==3.0.1",
    "sentencepiece==0.2.0",
    "fugashi==1.3.2",
    "unidic-lite==1.0.8",
]

pip_requirements = mlflow.pyfunc.get_default_pip_requirements() + extra_pip_requirements
pip_requirements

最後にlog_model関数を使って、モデルをロギング。あわせてUnity Catalogにも登録します。

import mlflow

with mlflow.start_run():
    model_info = mlflow.pyfunc.log_model(
        python_model="sentence_transformers_embedding.py",
        artifact_path="model",
        artifacts={
            "model_path": model_path,
            "config_file": "mlflow_model_config.json",
        },
        input_example=[input_example],
        signature=signature,
        pip_requirements=pip_requirements,
        registered_model_name=registered_model_name,
    )

Step3. Endpointの作成

Model ServingのAPIを実行してエンドポイントを作成します。
今回はCPUクラスタにしましたが、初回はだいたい10分程度で利用可能になります。

import requests
import json
from pprint import pprint
from mlflow import MlflowClient

client=MlflowClient()
versions = [mv.version for mv in client.search_model_versions(f"name='{registered_model_name}'")]

# 現在のノートブックコンテキストのAPIエンドポイントとトークンを取得
API_ROOT = (
    dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiUrl().get()
)
API_TOKEN = (
    dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get()
)

data = {
    "name": endpoint_name,
    "config": {
        "served_entities": [
            {
                "entity_name": registered_model_name,
                "entity_version": versions[0],
                "workload_type": endpoint_workload_type,
                "workload_size": endpoint_workload_size,
                "scale_to_zero_enabled": endpoint_scale_to_zero_enabled,
            }
        ]
    },
}

headers = {"Context-Type": "text/json", "Authorization": f"Bearer {API_TOKEN}"}

response = requests.post(
    url=f"{API_ROOT}/api/2.0/serving-endpoints", json=data, headers=headers
)

print(json.dumps(response.json(), indent=4))

ブラウザから実行すると、ちゃんと結果が返ってきました。

image.png

おまけでLangChainを使ってエンドポイントのAPIを叩いてみます。

%pip install -U langchain_core langchain_community
%pip install -U "mlflow-skinny[databricks]>=2.16.0"

dbutils.library.restartPython()
from langchain_community.embeddings import DatabricksEmbeddings

embeddings = DatabricksEmbeddings(
    target_uri="databricks",
    endpoint="embedding-pkshatech-RoSEtta-base-ja-endpoint",
)

embeddings.embed_query("Databricksとは何?")
出力
[0.11444036662578583,
 -1.6237077713012695,
 1.0107405185699463,
 -0.20084182918071747,
 -0.2939634919166565,
(以下省略)

まとめ

LLMだけでなく高性能な日本語埋め込みモデルも次々と公開されてきており、この領域の進歩の速さや注目度の高さが伺えます。いい時代だなあ。
そしてDatabricksのModel Serving機能を使うと、割と簡単に本番想定のAPIエンドポイントを準備できるのは良いですね。

改善点としては、よりパフォーマンスを出すのであればSentence TransformersではなくてInfinityなどのより高効率なパッケージを使うなども考えられます。
あとはエラーハンドリング関連の実装など。

性能的な部分の確認はまだ出来ていないので、どのように使い分けをするべきなのか・Fine Tuningを目指すべきなのか、いじっていきたいと思います。

2
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?