2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

導入

先日、CPUでも高速動作する日本語StaticEmbeddingモデルが公開されました。

記事に記載されているように、TransformerモデルではないEmbeddingモデルであり、ベンチマーク比較ではintfloat/multilingual-e5-small比較で85%のスコア、CPU動作時の速度が126倍高速(!)という驚くべきパフォーマンスです。

最近のEmbeddingモデルは肥大化の傾向で、もはやGPUを使うことが前提という印象でしたが、CPUで高速に動作するのであれば多少性能を犠牲にしても使えるユースケースが非常に多いのではないかと思います。

というわけで、Databricksを使って試してみます。
実施環境はDatabricks on AWS、クラスタはCPUで動かすのでサーバレスを使っています。

Step1. とりあえず動かしてみる

まずは記事内のサンプルをちょっと変えて動かしてみます。

必要なパッケージをインストール。MLflowも合わせて最新化しておきます。

%pip install "sentence-transformers==3.4.0" "mlflow-skinny[databricks]==2.20.0"

dbutils.library.restartPython() 

では、記事内のサンプルコードをいじって実行します。

from sentence_transformers import SentenceTransformer

model_name = "hotchpotch/static-embedding-japanese"
model = SentenceTransformer(model_name, device="cpu")

query = "使いやすいAI+データプラットフォームが欲しい"
docs = [
    "AIとデータプラットフォームを活用して、業務効率を大幅に向上させる方法を探しています。",
    "最新のAI技術を駆使したデータ分析プラットフォームが欲しいです。",
    "データの可視化とAIによる予測分析が一体となったプラットフォームを導入したい。",
    "AIを活用したデータプラットフォームで、ビジネスインサイトを得ることが目標です。",
    "使いやすいAIツールとデータプラットフォームを組み合わせて、業務改善を図りたい。",
    "AIとデータプラットフォームの連携で、迅速な意思決定をサポートするシステムが必要です。",
    "データの収集から分析までを一貫して行えるAIプラットフォームを探しています。",
    "AI技術を取り入れたデータプラットフォームで、競争力を高めたいと考えています。",
    "使いやすいインターフェースを持つAIデータプラットフォームが求められています。",
    "AIとデータプラットフォームの統合で、業務の効率化と精度向上を目指しています。",
]

embeddings = model.encode([query] + docs)
print(embeddings.shape)
similarities = model.similarity(embeddings[0], embeddings[1:])
for i, similarity in enumerate(similarities[0].tolist()):
    print(f"{similarity:.04f}: {docs[i]}")
実行結果
(11, 1024)
0.6675: AIとデータプラットフォームを活用して、業務効率を大幅に向上させる方法を探しています。
0.7444: 最新のAI技術を駆使したデータ分析プラットフォームが欲しいです。
0.6356: データの可視化とAIによる予測分析が一体となったプラットフォームを導入したい。
0.6183: AIを活用したデータプラットフォームで、ビジネスインサイトを得ることが目標です。
0.7177: 使いやすいAIツールとデータプラットフォームを組み合わせて、業務改善を図りたい。
0.6361: AIとデータプラットフォームの連携で、迅速な意思決定をサポートするシステムが必要です。
0.6491: データの収集から分析までを一貫して行えるAIプラットフォームを探しています。
0.5920: AI技術を取り入れたデータプラットフォームで、競争力を高めたいと考えています。
0.8353: 使いやすいインターフェースを持つAIデータプラットフォームが求められています。
0.6374: AIとデータプラットフォームの統合で、業務の効率化と精度向上を目指しています。

結果は1秒程度で返ってきます。CPU動作なのに類似度計算が早い・・!
SentenceTransformerで簡単に処理が書けるので便利。

Step2. MLflowと組み合わせて使う

Databricksで使うならMLflowと組み合わせないと・・・!という使命感の元、MLflowに登録して使ってみます。

実のところ、MLflowはSentenceTransformerのフレーバーがあるので以下のサイトにあるように書くと簡単にモデルのロギングや利用ができます。

ただ、現状のMLFlow(Version 2.20.0)だと、SentenceTransformerのバージョンが2.2.0-2.4.0の範囲で正式対応している模様。バージョン範囲外のSentenceTransformerを使うので、フレーバーの標準機能を使わずにPyfuncカスタムモデルでモデルをロギング・使用してみます。

まず、モデルをDatabricks上に保存。

from sentence_transformers import SentenceTransformer

# モデルの保存
model_name = "hotchpotch/static-embedding-japanese"
model = SentenceTransformer(model_name)
model_path = "./sentence_transformers_model"
model.save(model_path)

次にSentenceTransformerをラップするpyfuncカスタムモデルを定義し、MLflowでロギングします。
せっかくなのでOpenAIのEmbedding仕様と同様の入出力をカバーするようにしました。
(MLflow SentenceTransformerのフレーバーを使えば標準でOpenAI Embedding仕様をカバーできるのですが、最新バージョンのSentenceTransformersを使うとpredict実行時にエラーになったので使用を断念)

import mlflow
from mlflow.pyfunc import PythonModel
import pandas as pd
from mlflow.models import ModelSignature
from mlflow.types.llm import (
    EMBEDDING_MODEL_INPUT_SCHEMA,
    EMBEDDING_MODEL_OUTPUT_SCHEMA,
)
from mlflow.exceptions import MlflowException

def postprocess_output_for_llm_v1_embedding_task(
    input_prompts: list[str],
    output_tensors: list[list[float]],
    tokenizer,
):
    
    prompt_tokens = sum(
        len(tokenizer.encode(prompt).ids) for prompt in input_prompts
    )
    return {
        "object": "list",
        "data": [
            {
                "object": "embedding",
                "index": i,
                "embedding": tensor,
            }
            for i, tensor in enumerate(output_tensors)
        ],
        "usage": {"prompt_tokens": prompt_tokens, "total_tokens": prompt_tokens},
    }


class StaticEmbeddingModel(PythonModel):
    def __init__(self):
        self.model = None

    def load_context(self, context):
        """コンテキスト情報をロード"""
        self.model = SentenceTransformer.load(context.artifacts["model_path"])

    def predict(self, context, model_input, params=None):
        """埋め込みモデルを使ってベクトル計算した結果を返す"""

        sentences = model_input["input"]
        if params:
            try:
                output_data = self.model.encode(sentences, **params)
            except TypeError as e:
                raise MlflowException.invalid_parameter_value(
                    "Received invalid parameter value for `params` argument"
                ) from e
        else:
            output_data = self.model.encode(sentences)

        output_data = postprocess_output_for_llm_v1_embedding_task(
            sentences, output_data, self.model.tokenizer
        )
        return output_data

# シグネチャの設定。OpenAI Embedding仕様とコンパチにする
signature = ModelSignature(
    inputs=EMBEDDING_MODEL_INPUT_SCHEMA, outputs=EMBEDDING_MODEL_OUTPUT_SCHEMA
)

# モデルの保管場所
model_path = "./sentence_transformers_model"

# MLflowモデル登録
with mlflow.start_run():
    model_info = mlflow.pyfunc.log_model(
        artifact_path="embedding_model",
        python_model=StaticEmbeddingModel(),
        artifacts={"model_path": model_path},
        signature=signature,
    )

ロギングしたモデルを使って埋め込みを実行してみます。

docs = [
    "Databricksはデータエンジニアリングとデータサイエンスのための統合プラットフォームです。",
    "Databricksを使用すると、データの処理と分析が簡単になります。",
    "DatabricksはApache Sparkをベースにしたクラウドサービスです。",
    "Databricksは大規模なデータ処理に最適です。",
    "Databricksは機械学習モデルのトレーニングとデプロイをサポートします。",
    "Databricksはデータの可視化と共有を容易にします。",
    "Databricksはデータパイプラインの構築を簡素化します。",
    "Databricksはリアルタイムデータ分析を可能にします。",
    "Databricksはデータのセキュリティとガバナンスを提供します。",
    "Databricksはチームでのコラボレーションを促進します。"
]
df = pd.DataFrame(
    {
        "input": docs,
    }
)

loaded = mlflow.pyfunc.load_model(model_info.model_uri)
result = loaded.predict(df)
display(result)
実行結果
{'object': 'list',
 'data': [{'object': 'embedding',
   'index': 0,
   'embedding': array([-0.08232033,  2.028538  ,  1.2824306 , ..., -2.1574838 ,
          -1.1925044 , -0.07672863], dtype=float32)},
  {'object': 'embedding',
   'index': 1,
   'embedding': array([-0.31956077,  0.7703177 ,  0.12106168, ..., -0.6689819 ,
          -0.72449917, -1.0694993 ], dtype=float32)},
  {'object': 'embedding',
   'index': 2,
   'embedding': array([-1.1472858 ,  2.6340609 ,  0.23602125, ..., -1.1709659 ,
          -1.2186396 , -0.86111856], dtype=float32)},
  {'object': 'embedding',
   'index': 3,
   'embedding': array([-0.09853197,  1.6103554 ,  0.17753945, ..., -0.46319506,
          -0.9745547 , -0.3976228 ], dtype=float32)},
  {'object': 'embedding',
   'index': 4,
   'embedding': array([-0.26529256,  0.9429942 , -0.06867654, ..., -0.08740149,
           0.38548303,  0.27280912], dtype=float32)},
  {'object': 'embedding',
   'index': 5,
   'embedding': array([-0.18773255,  1.2199514 ,  0.71247834, ..., -2.6223009 ,
          -1.8244014 , -1.4321138 ], dtype=float32)},
  {'object': 'embedding',
   'index': 6,
   'embedding': array([-0.93844414,  1.6274796 , -1.3148386 , ..., -2.1880143 ,
          -1.8378772 , -1.2620323 ], dtype=float32)},
  {'object': 'embedding',
   'index': 7,
   'embedding': array([ 0.24410263,  1.2611744 , -0.2576989 , ..., -0.518362  ,
          -1.745497  , -0.7675792 ], dtype=float32)},
  {'object': 'embedding',
   'index': 8,
   'embedding': array([-0.24479014,  0.71801883, -0.34102565, ..., -0.6883325 ,
          -2.2200572 , -0.46863177], dtype=float32)},
  {'object': 'embedding',
   'index': 9,
   'embedding': array([-0.87576395,  2.3160427 , -0.2711031 , ..., -2.1476026 ,
          -0.99900407, -1.3423741 ], dtype=float32)}],
 'usage': {'prompt_tokens': 183, 'total_tokens': 183}}

是非実行してみ欲しいのですが、CPUでもかなり高速に動作します。すごいな。。。
まだ試してませんが、Databricks Mosaic AI Model Servingでデプロイすると超軽量な埋め込みモデルサーバとして利用できるのではないでしょうか。

まとめ

日本語StaticEmbeddingをDatabricsk+MLflowも組み合わせて試してみました。

精度面の評価は全然していないのですが、ベンチマーク通りならユースケースによっては十分な精度を備えていると思いますし、冒頭記載したようにいろんな利用ができると思います。
次元削減にも対応しているので、精度とのトレードオフになりますがより軽量を活かした使い方もできそうです。

元記事内でも記載されていますが、正直、非Transformerモデルでこれだけの性能が出るのは個人的にも驚きです。BM25の置き換えに確かに使えるのかも。もう少し深堀ってみます。


私事ですが、ようやくいろいろ落ち着いてきたので、投稿ペースを頑張って戻していきたいと思います。
ネタがとても溜まっている。。。

2
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?