LoginSignup
1
1

MLflowのStreamable Python ModelsでLLMのストリーミング出力をやってみる

Last updated at Posted at 2024-05-21

導入

MLflowのv2.13.0がリリースされていました。

※ 2024/5/21現在、こちらにはまだ未記載。ドキュメント類を準備中なんでしょうか。

主な機能としてStreamable Python Modelsが挙げられています。
これはLLMを対象として、predict_streamメソッドを実装することでストリーミング出力できるようになる拡張です。

実際のところ、v2.12.2から存在していた(LangChain flavorだと2.12.1から?)と思うのですが、そういえば試していなかったので、こちらを実装してストリーミング出力を試してみます。

検証はDatabricks on AWS、DBRは15.1ML、インスタンスタイプはg4dn.xlargeです。
推論エンジンにはExLlamaV2を使いますが、transformersやLlama.cppなど他の(ストリーミング出力に対応した)パッケージでも問題ないと思います。

Step1. パッケージインストール

torchとexllamav2パッケージ、そしてmlflowを最新化します。

%pip install torch==2.3.0 --index-url https://download.pytorch.org/whl/cu121

%pip install -U flash-attn --no-build-isolation
%pip install https://github.com/turboderp/exllamav2/releases/download/v0.0.21/exllamav2-0.0.21+cu121-cp311-cp311-linux_x86_64.whl

%pip install "mlflow-skinny[databricks]>=2.13.0"

dbutils.library.restartPython()

Step2. Mlflow Pyfunc Customモデルを作成

MLflow pyfuncカスタムモデルを定義します。
mlflow.pyfunc.PythonModelを継承してload_contextpredictメソッドを実装するのは従来通りですが、さらにpredict_streamメソッドを実装します。

predict_streamメソッドはpythonのgeneratorを返す形で定義する必要があるようです。
今回は生成したトークンをyieldで返すよう実装しました。
(パラメータ処理などを入れているため長くなっていますが、重要なのはpredict_streamのみです)

from typing import List
import uuid

import mlflow
from mlflow.types.llm import ChatMessage

# Define a custom PythonModel
class ExLlamaV2CustomModel(mlflow.pyfunc.PythonModel):
    def __init__(self, prompt_map, pre_prompt="", post_prompt=""):        
        self.prompt_map = prompt_map
        self.pre_prompt = pre_prompt
        self.post_prompt = post_prompt
        self._model = None
        self._cache = None
        self._tokenizer = None

    # プロンプトを整形
    def format_messages(self, messages):
        prompt = self.pre_prompt
        for mes in messages:
            template = self.prompt_map.get(mes.role)
            if template:
                prompt += template.format(mes.content)

        prompt += self.post_prompt

        return prompt

    def load_context(self, context):
        from exllamav2 import (
            ExLlamaV2,
            ExLlamaV2Config,
            ExLlamaV2Cache_Q4,
            ExLlamaV2Tokenizer,
        )
        from exllamav2.generator import ExLlamaV2BaseGenerator

        model_directory = context.artifacts["llm-model"]

        config = ExLlamaV2Config(model_directory)

        model = ExLlamaV2(config)
        print("Loading model: " + model_directory)

        # Q4 KV-cache
        cache = ExLlamaV2Cache_Q4(
            model,
            lazy=True,
            max_seq_len=8192,
        ) 
        model.load_autosplit(cache)

        tokenizer = ExLlamaV2Tokenizer(config)

        self._model = model
        self._cache = cache
        self._tokenizer = tokenizer

    def predict(self, context, model_input, params=None):
        """ 従来の予測インターフェース """

        from exllamav2.generator import ExLlamaV2BaseGenerator

        # プロンプトの作成
        prompt = self._build_prompt_from_model_input(model_input["prompt"], params)[0]

        # サンプリングパラメータの作成
        settings, max_tokens, _ = self._build_sampling_params(params)

        generator = ExLlamaV2BaseGenerator(self._model, self._cache, self._tokenizer)
        output = generator.generate_simple(
            prompt,
            settings,
            max_tokens,
            seed=1234,
            add_bos=True,
            completion_only=True,
        )

        return output
    
    def predict_stream(self, context, model_input, params=None):
        """ 今回の要点。ストリーミングで出力する """

        from exllamav2.generator import (
            ExLlamaV2StreamingGenerator,
        )

        # プロンプトの作成
        prompt = self._build_prompt_from_model_input([model_input["prompt"]], params)[0]

        # サンプリングパラメータの作成
        settings, max_tokens, stop = self._build_sampling_params(params)

        generator = ExLlamaV2StreamingGenerator(self._model, self._cache, self._tokenizer)

        input_ids = self._tokenizer.encode(prompt, add_bos = True)

        generator.set_stop_conditions(stop)
        generator.begin_stream_ex(input_ids, settings)

        generated_tokens = 0
        while True:
            res = generator.stream_ex()
            chunk = res["chunk"]
            eos = res["eos"]

            generated_tokens += 1
            # 出力できるのは文字列 or 辞書型のみ
            yield {"content": chunk}

            if eos or generated_tokens == max_tokens: break
    

    def _build_prompt_from_model_input(self, model_input, params):
        """プロンプトを構築する"""

        DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant."
        system_prompt = (
            params.get("system_prompt", DEFAULT_SYSTEM_PROMPT)
            if params
            else DEFAULT_SYSTEM_PROMPT
        )
        prompts = [
            self.format_messages(
                [
                    ChatMessage(role="system", content=system_prompt),
                    ChatMessage(role="user", content=p),
                ]
            )
            for p in model_input
        ]
        return prompts

    def _build_sampling_params(self, params):
        """サンプリングパラメータを構築する"""
        from exllamav2.generator import ExLlamaV2Sampler

        _params = params.copy() if params else {}

        # サンプリングの設定
        settings = ExLlamaV2Sampler.Settings()
        settings.temperature = _params.get("temperature", 1.0)
        settings.top_k = _params.get("top_k", 50)
        settings.top_p = _params.get("top_p", 0.8)
        settings.token_repetition_penalty =  _params.get("repetition_penalty", 1.05)
        max_tokens = _params.get("max_tokens", 10)
        stop = [self._tokenizer.single_id(_params.get("stop"))] if _params.get("stop", None) else []

        return settings, max_tokens, stop


    def __getstate__(self):
        # ExLlamaV2モデルはPickle化から除外
        state = self.__dict__.copy()
        del state["_model"]
        del state["_cache"]
        del state["_tokenizer"]
        return state

Step3. Mlflowへ登録

定義したカスタムモデルクラスを使ってMlflowへモデルをロギングします。

まずはスキーマを定義。

import numpy as np
import pandas as pd

import mlflow
from mlflow.models.signature import ModelSignature
from mlflow.types import ColSpec, DataType, ParamSchema, ParamSpec, Schema

# Define input and output schema
input_schema = Schema(
    [
        ColSpec(DataType.string, "prompt"),
    ]
)
output_schema = Schema([ColSpec(DataType.string, "output")])

parameters = ParamSchema(
    [
        ParamSpec("temperature", DataType.float, np.float32(0.1), None),
        ParamSpec("max_tokens", DataType.integer, np.int32(512), None),
        ParamSpec("top_k", DataType.integer, np.int32(50), None),
        ParamSpec("top_p", DataType.float, np.float32(0.8), None),
        ParamSpec("repetition_penalty", DataType.float, np.float32(1.0), None),
        ParamSpec("stop", DataType.string, None, None),
        ParamSpec(
            "system_prompt", DataType.string, "You are a helpful assistant.", None
        ),
    ]
)

signature = ModelSignature(
    inputs=input_schema, outputs=output_schema, params=parameters
)

# Define input example
input_example = [
    {"prompt": "What is Databricks?"},
]

モデルをロギング。
事前にダウンロードしておいた以下のモデルを使っています。

model_path = "/Volumes/training/llm/model_snapshots/models--bartowski--Meta-Llama-3-8B-Instruct-special-tokens-exl2--8_0/"
prompt_map = {
    "system": "<|start_header_id|>system<|end_header_id|>\n\n{}<|eot_id|>",
    "user": "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>",
    "assistant": "<|start_header_id|>assistant<|end_header_id|>\n\n{}<|eot_id|>",
}
post_prompt = "<|start_header_id|>assistant<|end_header_id|>\n\n"

model = ExLlamaV2CustomModel(prompt_map, post_prompt=post_prompt)

with mlflow.start_run() as run:
    _ = mlflow.pyfunc.log_model(
        artifact_path="model",
        python_model=model,
        artifacts={
            "llm-model": model_path,
        },
        signature=signature,
        input_example=input_example,
        example_no_conversion=True,
        await_registration_for=1200,  # モデルサイズが大きいので長めの待ち時間にします
    )

これでモデルの準備は完了です。

Step4. ストリーミング出力を試してみる

登録したモデルをロードして使ってみましょう。

まずは、登録モデルをpyfunc.load_modelでロード。

import mlflow

# 最新のrun_idを取得
order_by = ["start_time desc"]  # 開始時間で降順にソート
runs = mlflow.search_runs(order_by=order_by, max_results=1)
latest_run_id = runs.loc[0, "run_id"]
logged_model = f'runs:/{latest_run_id}/model'

# PyFuncModelとしてモデルをロード
loaded_model = mlflow.pyfunc.load_model(logged_model)

ストリーミング出力の前に、まずは従来通りpredictで推論してみます。

loaded_model.predict([{"prompt": "what is databricks?"}], params={"max_tokens":50})
出力
'Databricks is a cloud-based data engineering platform that provides a collaborative workspace for data engineers, data scientists, and data analysts to work together on big data processing, analytics, and machine learning tasks. It was founded in 2013 by the original'

問題なく推論できていますね。

では、ストリーミング出力。
predict_streamはgeneratorを返すので、forでイテレーションを回してトークンを出力できます。

for c in loaded_model.predict_stream(
    {"prompt": "大阪の観光名所をひとつ教えて。"},
    params={"max_tokens": 50, "stop": "<|eot_id|>", "system_prompt":"あなたは日本語を話すAIアシスタントです。日本語で必ず回答してください。"},
):
    print(c["content"], end="|")
出力
D|atab|ricks|は|、|Apache| Spark|を|基|盤|とした|データ|エ|ンジ|ン|と|ワ|ーク|ベ|ンチ|ュ|ール|ソ|フト|ウェ|ア|です|。|201|3|年に|創|業|された|D|atab|ricks|は|、|Spark|の|創|始|者|である|Apache| Software| Foundation|の|メ|ンバー|である|Ion| Sto|ica|、|Mate|i| Zah|aria|、|Andrew| Kon|w|inski|、|Patrick| Wend|ell|によって|設|立|され|ました|。

|D|atab|ricks|は|、|Spark|を|基|盤|とした|データ|分析|プラ|ット|フォ|ーム|を|提供|し|、|データ|サイ|エ|ンティ|スト|や|エ|ンジ|ニア|が|、|データ|を|分析|、|処|理|、|スト|レ|ージ|、|ビ|ジ|ネ|ス|イン|テ|リ|ジェ|ンス|に|役|立|つ|ツ|ール|を|提供|しています|。|D|atab|ricks|は|、|Cloud|ベ|ース|の|サービス|として|提供|されて|おり|、|AWS|、|Azure|、|G|CP|などの|Cloud|プラ|ット|フォ|ーム|上|で|実|行|する|ことが|できます|。

|D|atab|ricks|の|主|な|機能|には|、|以下|のような|もの|があり|ます|。

|*| Apache| Spark|を|基|盤|とした|高速|な|データ|分析|
|*| ス|ケ|ーラ|ブル|な|データ|スト|レ|ージ|と|処|理|
|*| ビ|ジュ|アル|な|データ|分析|ツ|ール|
|*| コ|ラ|ボ|レ||ーシ|ョ|ン|機能|による|チーム|ワ|ーク|
|*| セ|キュ|リ|ティ|ー|と|コ|ンプ|ライ|ア|ンス|に対|応|した|設計|

|D|atab|ricks|は|、|データ|サイ|エ|ンティ|スト|、|エ|ンジ|ニア|、|ビ|ジ|ネ|ス|ユ||ーザ|ー|など|、|幅|広|い|範|囲|の|ユ||ーザ|ー|に|役|立|つ|ツ|ール|です|。||

|区切りで生成トークン単位にストリーミング出力できました。

まとめ

predict_streamによってMLFlowにロギングしたLLMを使ってストリーミング出力できるようになりました。
このインターフェースが出来たということは、MLflowからデプロイしたサーバ上で、おそらくストリーミング出力がサポートされるようになるのだと思います。
(つまり、Databricksのモデルサービングもストリーミング対応する?Databricks社の皆様よろしくお願いします!)

MLflowもバージョンを重ねてLLM周りもかなり整備されてきた印象です。
今後もどんどん発展していくことを期待しています。

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1