導入
ここ最近PKSHA Technologies様や名古屋大学様より高性能な日本語埋め込みモデルやそれらを利用したRerankerが次々と公開されています。
以下のように既にどちらも試されておられる方もいらっしゃいます。
詳細はこちらで確認いただくのがよいと思います。
軽量・高性能で日本語に強い埋め込みモデルは一般的なRAG利用において非常に重要です。
せっかくなので、これらをDatabricks Mosaic AI Model Servingを使ってAPIエンドポイントを作成するやり方をまとめてみます。
簡易的な実装なので、改善点は多数ありますが。
構築はDatabricks on AWS上で行いました。ノートブック上のクラスタはサーバレスクラスタで動作します。
上記、様々なモデルが公開されていますが、この記事ではPKSHA Technologies様のRoSEtta-base-jaのエンドポイントを公開してみます。
(設定を一部変更することでpkshatech/GLuCoSE-base-ja-v2やcl-nagoya/ruri-largeが同様に動作することも確認済み)
埋め込みモデル:RoSEtta-base-jaとは
公式記事より。
LLMと検索拡張生成(Retrieval-Augmented Generation、以下RAG)の活用において、長文を含む多様なドキュメント処理のニーズが高まっています。しかし、現在の日本語文埋め込みモデルの多くは最大入力長系列について512トークンまでの制限があり、1024トークン以上を扱える実用的な軽量モデルが存在しませんでした。
本研究では、長い系列を扱う場合に適切とされている相対位置埋め込み「RoPE」を取り入れたBERT、「RoFormer」に事前学習・事後学習を行い、最大1024トークンの系列を扱うことのできる日本語文埋め込みモデルを構築しました。GLuCoSE v2と同様の蒸留と検索に特化した追加学習によって、包括的な評価において先行研究に対して同等以上の性能になりました。本研究の成果であるモデルをRoSEtta(RoFormer-based Sentence Encoder through Distillation)という名称で、商用利用可能なライセンスで公開いたしました(https://huggingface.co/pkshatech/RoSEtta-base-ja)。
最大の特徴は上記記載のように入力長が1024トークンまで大きくなっていることですね。
多くの軽量埋め込みモデルは512トークンまでの入力にしか対応していないものが多いのですが、この埋め込みモデルではより多くの入力に対してベクトル化が可能です。
では、このモデルのAPIエンドポイントを作成していきます。
おおまかなステップとしては、以下のようになります。
- MLflow登録用のカスタムpyfuncモデルを定義
- MLflowにモデルをロギング
- Mosaic AI Model Servingでエンドポイントを作成
Step1. MLflow登録用のカスタムpyfuncモデルを定義
Databricks Mosaic AI Model Servingを利用するためにはMLflowにモデルを登録する必要があります。
MLflowは標準でSentence Transformersモデルをさくっとロギングするフレーバーが存在しているのですが、Model Serving利用時に入出力仕様の都合が悪かったので、pyfuncカスタムモデルをまずは定義します。
ノートブックを作成して、まずは関連パッケージをインストール。
%pip install -U "sentence-transformers==3.0.1" "sentencepiece==0.2.0" "fugashi==1.3.2" "unidic-lite==1.0.8"
%pip install -U "mlflow-skinny[databricks]>=2.16.0"
dbutils.library.restartPython()
fugashi
やunidic-lite
はRoSEtta-base-janの動作には不要です。
cl-nagoya/ruri-largeを動かす際には必要になるので合わせてインストールするようにしています。
このあとの処理用に、パラメータを変数格納しておきます。
※ なお、事前に変数model_path
の場所へこちらのモデルファイルをダウンロードしてあります。
# MLflowのUnity Catalog登録時の設定
catalog='training'
schema='llm'
registered_model_name='training.llm.pkshatech-RoSEtta-base-ja'
model_path='/Volumes/training/llm/model_snapshots/models--pkshatech--RoSEtta-base-ja/'
# Model Serving Endpoint構築時の設定情報
endpoint_name='embedding-pkshatech-RoSEtta-base-ja-endpoint'
endpoint_workload_type='CPU'
endpoint_workload_size='Small'
endpoint_scale_to_zero_enabled='true'
# Sentence Transformersに設定するプロンプトテンプレート
# encode時にこのプレフィクスを追加して処理実行できる
prompt_template_query='query: '
prompt_template_document='passage: '
MLflowのアーティファクトとして保存する、Sentence Transformers用のJSON設定ファイルを作成します。
これをカスタムモデルの中で読み込んで参照します。
# MLflow用のconfigを作成
import json
from pprint import pprint
config = {
"model_name": registered_model_name,
"prompts": {
"search_query": prompt_template_query,
"search_document": prompt_template_document,
},
"default_prompt_name": "search_document",
}
json.dump(config, open("mlflow_model_config.json", "w"))
# 保存結果表示
pprint(json.load(open("mlflow_model_config.json")))
{'default_prompt_name': 'search_document',
'model_name': 'training.llm.pkshatech-RoSEtta-base-ja',
'prompts': {'search_document': 'passage: ', 'search_query': 'query: '}}
次が最大のポイントである、MLflow用のpyfuncカスタムモデルを定義です。
MLflowのModels from Code機能を使って定義しています。
※ このセルを実行すると同じフォルダ内にsentence_transformers_embedding.py
が作られ、この後のMLflowのロギング時にこのファイルを指定します。
少しコード量が多いですが、主だった処理はSentence Transformersを使ってモデルを読み込み、predict
メソッドでencode処理を実行しているところです。
また、Sentence Transformersのプロンプトテンプレート機能を使って、query: などのプレフィックスをencode時に付加しています。
入出力の仕様については、こちらに記載の仕様に準拠しました。
%%writefile "./sentence_transformers_embedding.py"
import json
import logging
from mlflow.models import set_model
from mlflow.pyfunc import PythonModel
from sentence_transformers import SentenceTransformer
logger = logging.getLogger(__name__)
class SentenceTransformersEmbeddingModel(PythonModel):
def load_context(self, context):
"""Load the model from the given context."""
model_path = context.artifacts["model_path"]
config_file = context.artifacts["config_file"]
logger.info(f"Loading model from {model_path}")
logger.info(f"Loading config from {config_file}")
config = json.load(open(config_file))
prompts = config.get("prompts")
default_prompt_name = config.get("default_prompt_name")
model_name = config.get("model_name", model_path)
self.model = SentenceTransformer(
model_path,
prompts=prompts,
default_prompt_name=default_prompt_name,
trust_remote_code=True,
)
self.model_name = model_name
logger.info("Model and config loaded successfully")
def predict(self, context, model_input):
""" Predict embeddings for the given input. """
if self.model is None:
raise ValueError(
"The model has not been loaded. "
"Ensure that 'load_context' is properly executed."
)
# Pandas DataFrameの先頭データのみ利用
dict_input = model_input.head(1).to_dict("records")[0]
logger.info("Received input: {dict_input}")
texts = dict_input.get("input", [])
prompt = dict_input.get("input_type", None)
# encode
logger.info(f"Received input for prediction: {texts}")
logger.info(f"Prompt Template: {prompt}")
result = self.model.encode(texts, prompt=prompt)
# usage
input_ids = self.model.tokenize(texts)["input_ids"]
prompt_tokens = sum([len(id) for id in input_ids])
logger.info(f"Prompt Tokens: {prompt_tokens}")
return self._build_response(self.model_name, result, prompt_tokens)
@staticmethod
def _build_response(model_name, embeddings, prompt_tokens):
data = [
{"object": "embedding", "index": i, "embedding": emb.tolist()}
for i, emb in enumerate(embeddings)
]
return {
"object": "list",
"data": data,
"model": model_name,
"usage": {"prompt_tokens": prompt_tokens, "total_tokens": prompt_tokens},
}
set_model(SentenceTransformersEmbeddingModel())
これで準備完了です。
Step2. MLflowにモデルをロギング
次にMLflowにStep1で定義したカスタムモデルクラスを使ってロギングします。
まずは入出力仕様をSignatureとして定義。
# 以下の仕様を参照
# https://docs.databricks.com/ja/machine-learning/model-serving/score-foundation-models.html#query-an-embedding-model
from mlflow.models import infer_signature
from mlflow.types.schema import Schema, ColSpec
input_example = {"input": ["embedding text"]}
output_example = {
"object": "list",
"data": [
{
"object": "embedding",
"index": 0,
"embedding": [
2.0787181854248047,
0.27202653884887695,
],
}
],
"model": "sample_embedding_model",
"usage": {"prompt_tokens": 2, "total_tokens": 2},
}
signature = infer_signature(input_example, output_example)
# input_typeオプションパラメータを追加
signature.inputs.inputs.append(
ColSpec(name="input_type", type="string", required=False)
)
signature
以下のような定義となります。
inputs:
['input': Array(string) (required), 'input_type': string (optional)]
outputs:
['object': string (required), 'data': Array({embedding: Array(double) (required), index: long (required), object: string (required)}) (required), 'model': string (required), 'usage': {prompt_tokens: long (required), total_tokens: long (required)} (required)]
params:
None
次に依存関係の設定を作成。
必要最低限の依存関係のみ含まれるようにしています。
import mlflow
extra_pip_requirements = [
"sentence-transformers==3.0.1",
"sentencepiece==0.2.0",
"fugashi==1.3.2",
"unidic-lite==1.0.8",
]
pip_requirements = mlflow.pyfunc.get_default_pip_requirements() + extra_pip_requirements
pip_requirements
最後にlog_model
関数を使って、モデルをロギング。あわせてUnity Catalogにも登録します。
import mlflow
with mlflow.start_run():
model_info = mlflow.pyfunc.log_model(
python_model="sentence_transformers_embedding.py",
artifact_path="model",
artifacts={
"model_path": model_path,
"config_file": "mlflow_model_config.json",
},
input_example=[input_example],
signature=signature,
pip_requirements=pip_requirements,
registered_model_name=registered_model_name,
)
Step3. Endpointの作成
Model ServingのAPIを実行してエンドポイントを作成します。
今回はCPUクラスタにしましたが、初回はだいたい10分程度で利用可能になります。
import requests
import json
from pprint import pprint
from mlflow import MlflowClient
client=MlflowClient()
versions = [mv.version for mv in client.search_model_versions(f"name='{registered_model_name}'")]
# 現在のノートブックコンテキストのAPIエンドポイントとトークンを取得
API_ROOT = (
dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiUrl().get()
)
API_TOKEN = (
dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get()
)
data = {
"name": endpoint_name,
"config": {
"served_entities": [
{
"entity_name": registered_model_name,
"entity_version": versions[0],
"workload_type": endpoint_workload_type,
"workload_size": endpoint_workload_size,
"scale_to_zero_enabled": endpoint_scale_to_zero_enabled,
}
]
},
}
headers = {"Context-Type": "text/json", "Authorization": f"Bearer {API_TOKEN}"}
response = requests.post(
url=f"{API_ROOT}/api/2.0/serving-endpoints", json=data, headers=headers
)
print(json.dumps(response.json(), indent=4))
ブラウザから実行すると、ちゃんと結果が返ってきました。
おまけでLangChainを使ってエンドポイントのAPIを叩いてみます。
%pip install -U langchain_core langchain_community
%pip install -U "mlflow-skinny[databricks]>=2.16.0"
dbutils.library.restartPython()
from langchain_community.embeddings import DatabricksEmbeddings
embeddings = DatabricksEmbeddings(
target_uri="databricks",
endpoint="embedding-pkshatech-RoSEtta-base-ja-endpoint",
)
embeddings.embed_query("Databricksとは何?")
[0.11444036662578583,
-1.6237077713012695,
1.0107405185699463,
-0.20084182918071747,
-0.2939634919166565,
(以下省略)
まとめ
LLMだけでなく高性能な日本語埋め込みモデルも次々と公開されてきており、この領域の進歩の速さや注目度の高さが伺えます。いい時代だなあ。
そしてDatabricksのModel Serving機能を使うと、割と簡単に本番想定のAPIエンドポイントを準備できるのは良いですね。
改善点としては、よりパフォーマンスを出すのであればSentence TransformersではなくてInfinityなどのより高効率なパッケージを使うなども考えられます。
あとはエラーハンドリング関連の実装など。
性能的な部分の確認はまだ出来ていないので、どのように使い分けをするべきなのか・Fine Tuningを目指すべきなのか、いじっていきたいと思います。