0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

EmbeddingGemmaをDatabricks上でサーブする

Posted at

導入

Google社が埋め込みモデルEmbbeddingGemmaを公開しました。

特徴を上記より抜粋。

EmbeddingGemma は、Gemma 3 に基づく 3 億 800 万のパラメータを持つ多言語テキスト エンベディング モデルです。スマートフォン、ノートパソコン、タブレットなどの日常的なデバイス向けに最適化されています。このモデルは、情報検索、セマンティック類似性検索、分類、クラスタリングなどのダウンストリーム タスクで使用するテキストの数値表現を生成します。

EmbeddingGemma には次の主な特徴があります。

多言語サポート: 100 以上の言語でトレーニングされた幅広い言語データ理解。
柔軟な出力ディメンション: Matryoshka Representation Learning(MRL)を使用して、速度とストレージのトレードオフのために出力ディメンションを 768 から 128 までカスタマイズします。
2K トークン コンテキスト: ハードウェア上でテキストデータやドキュメントを直接処理するための実質的な入力コンテキスト。
ストレージ効率が高い: 量子化により 200 MB 未満の RAM で実行
低レイテンシ: EdgeTPU で 22 ミリ秒未満の生成エンベディングにより、高速でスムーズなアプリケーションを実現します。
オフラインで安全: ドキュメントのエンベディングをハードウェア上で直接生成します。インターネット接続なしで動作し、機密データを安全に保ちます。

300Mという比較的小さいパラメータ数ですが高い性能を示しているモデルです。また、MRL(Matryoshka Representation Learning)を介して、複数の次元数に対応しています。

このサイズならCPUでもそこそこの速さで推論できるんじゃないかと思い、DatabricksのModel Serving機能を使ってCPUインスタンスでサーブしてみた、という趣旨の記事になります。

検証はDatabricks on AWS上で行いました。
おそらくFree Editionでもできるのではないかと思います。

EmbeddingGemmmaをサーブする

試す場合は、事前準備として以下のhuggingfaceリポジトリからモデルの重みファイルをUnity Catalog Volumeにダウンロードしておいてください。

https://huggingface.co/google/embeddinggemma-300m

まずはノートブックを作成し、必要なパッケージをインストールします。
今回はSentence Transformersを使って埋め込みモデルを動かすことにします。

%pip install sentence_transformers==5.1.0 mlflow==3.3.2

%restart_python

埋め込みモデルを利用するためのMLflowカスタムクラスを定義します。
OpenAIのEmbbeding APIと近しい入出力仕様にしています。

%%writefile embedding_model.py

import mlflow
from mlflow.models import set_model
from mlflow.pyfunc import PythonModel
from mlflow.exceptions import MlflowException
from sentence_transformers import SentenceTransformer
import pydantic
from typing import Union, Any
import pandas as pd

def postprocess_output_for_llm_v1_embedding_task(
    input_prompts: list[str],
    output_tensors: list[list[float]],
    tokenizer,
) -> dict:
    """
    埋め込みタスクの出力を変換します。

    Args:
        input_prompts (list[str]): 入力プロンプトのリスト。
        output_tensors (list[list[float]]): 出力テンソルのリスト。
        tokenizer: プロンプトをエンコードするために使用されるトークナイザー。

    Returns:
        dict: 処理された出力を含む辞書。
    """
    prompt_tokens = sum(len(tokenizer.encode(prompt)) for prompt in input_prompts)
    return {
            "object": "list",
            "data": [
                {
                    "object": "embedding",
                    "index": i,
                    "embedding": tensor,
                }
                for i, tensor in enumerate(output_tensors)
            ],
            "usage": {"prompt_tokens": prompt_tokens, "total_tokens": prompt_tokens},
        }


class EmbeddingModel(PythonModel):
    def __init__(self, model=None):
        """
        EmbeddingModelを初期化します。

        Args:
            model: 埋め込みに使用するモデル。
        """
        self.model = model

    def load_context(self, context):
        """
        モデルコンテキストをロードします。

        Args:
            context: モデルアーティファクトを含むコンテキスト。
        """
        self.model = SentenceTransformer(context.artifacts["model_path"])

    def predict(
        self, context, model_input, params: dict = None
    ) -> dict:
        """
        指定された入力に対して埋め込みを予測します。

        Args:
            context: モデルアーティファクトを含むコンテキスト。
            model_input (list[EmbeddingModelInput]): 入力文を含むリスト。
            params (dict, optional): モデルの追加パラメータ。

        Returns:
            dict: 埋め込みを含む辞書。
        """

        if isinstance(model_input, pd.DataFrame):
            pass
        elif isinstance(model_input, list) or isinstance(model_input, dict):
            model_input = pd.DataFrame(model_input)
        else:
            raise MlflowException.invalid_parameter_value(
            "Received invalid parameter value for `model_input` argument"
        )

        sentences = model_input.explode('input')["input"]
        output_data = []
        if params:
            try:
                # パラメータ設定がある&単一文章の場合、クエリとして対応
                if "type" in params:
                    type = params.pop("type")
                    if type == "query" and len(sentences) == 1:
                        output_data = [self.model.encode_query(sentences[0], **params)]
                else:
                    output_data = self.model.encode_document(sentences, **params)
            except TypeError as e:
                raise MlflowException.invalid_parameter_value(
                    "Received invalid parameter value for `params` argument"
                ) from e
        else:
            output_data = self.model.encode_document(sentences)

        output_data = postprocess_output_for_llm_v1_embedding_task(
            sentences, output_data, self.model.tokenizer
        )
        return output_data


model = EmbeddingModel()
set_model(model)

定義したカスタムクラスを用いて、MLflowのModel Registryに埋め込みモデルを登録します。
こちらもOpenAIのEmbedding APIと同様の入出力仕様となるようにsignatureを設定しています。

import mlflow
from mlflow.pyfunc import PythonModel
from mlflow.types.schema import Schema, ColSpec, AnyType
from mlflow.models import ModelSignature
from mlflow.types.llm import (
    EMBEDDING_MODEL_OUTPUT_SCHEMA,
)

# EmbeddingGemmaの重みをダウンロードした場所を指定
model_path = (
    "/Volumes/training/llm/model_snapshots/models--google--embeddinggemma-300m/"
)

# 入力スキーマ/例
input_schema = Schema([
    ColSpec(type=AnyType(), name="input"),  # 入力は制限をかけない
])
input_example = [{"input": ["What is Databricks"]}]

# Unity Catalog上のモデル登録先
registered_model_name = "training.llm.embeddinggemma"

# MLflowモデル登録
with mlflow.start_run():
    model_info = mlflow.pyfunc.log_model(
        name="embedding_model",
        python_model="embedding_model.py",
        artifacts={"model_path": model_path},
        input_example=input_example,
        signature=ModelSignature(
            inputs=input_schema,
            outputs=EMBEDDING_MODEL_OUTPUT_SCHEMA,
        ),
        registered_model_name=registered_model_name,
    )

最後に登録したモデルを利用して、Model Servingにエンドポイントを作成します。
(デプロイ完了まで数十分程度かかります)

from mlflow.deployments import get_deploy_client

client = get_deploy_client("databricks")

endpoint = client.create_endpoint(
    config={
        "name":"embedding_gemma_endpoint",
        "config": {
            "served_entities": [
                {
                    "name": "emb-entity",
                    "entity_name": registered_model_name,
                    "entity_version": "1",
                    "workload_size": "Small",
                    "scale_to_zero_enabled": True,
                }
            ],
        },
    },
)

endpoint

これで準備は完了です。

使ってみる

別にノートブックを作成し、MLflowやLangChainから利用してみます。
まず、パッケージをインストール。

%pip install databricks-langchain mlflow

%restart_python

まずはMLflowのクライアントから利用してみます。

import mlflow.deployments

client = mlflow.deployments.get_deploy_client("databricks")

embeddings_response = client.predict(
    endpoint="embedding_gemma_endpoint",
    inputs={"input": "Here is some text to embed"},
)

embeddings_response

以下のように出力されます。

出力結果
{'object': 'list',
 'data': [{'object': 'embedding',
   'index': 0,
   'embedding': [-0.12576764822006226,
    -0.002920572878792882,
    0.0012626828392967582,
    0.016500266268849373,
    (中略)
    -0.019344018772244453]}],
 'usage': {'prompt_tokens': 8, 'total_tokens': 8}}

つぎにLangChainから。
複数の文字列をベクトル変換します。

import mlflow
from databricks_langchain import DatabricksEmbeddings

mlflow.langchain.autolog()

endpoint_name = "embedding_gemma_endpoint"
embeddings = DatabricksEmbeddings(
    endpoint=endpoint_name,
    query_params={"type": "query"},
)

embeddings.embed_documents(["Hello!", "World!"])

こちらも問題なく出力されました。

出力結果
[[-0.1302994340658188,
  -0.010217982344329357,
  0.02192031778395176,
  (中略)
  -0.020601240918040276,
  -0.0017430599546059966]]

正確な時間を測っていませんが、少ない件数であればそれなりに高速なレスポンスを得られます。

おわりに

EmbeddingGemmaをCPUインスタンスのエンドポイントで動かしてみました。
開発やテストで埋め込みモデルが必要なときに使えそうかなという印象です。

Databricksは日本語対応の埋め込み基盤モデルがPay-per-tokenといった形で公開されていないため、EmbeddingGemmaは軽量で使い勝手が良いように思います。プロビジョニングされた基盤モデルとして公開してくれないかなあ。。。

また、llama.cppを使って量子化したファイルでサーブするとさらに高速に動かせる気がします。余裕があればやってみようと思います。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?