2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

N番煎じでGoogle Gemma-2 2B JPN-itとRinna社のGemma 2 Baku 2BをDatabricks Mosaic AI Model Servingで試す

Posted at

タイトルが長い。。。

導入

GoogleのGemma-2 2Bを基盤とした日本語モデルが立て続けに公開されました。

ひとつはGoogle社自身がファインチューニングした日本語版 Gemma2 2B。

もうひとつがRinna社がGemma-2 2Bを継続事前学習と指示チューニングで構築したGemma2 Baku 2Bです。

以下のように既に多くの方が試されていますが、後追い上等という感じで試してみたいと思います。

検証はDatabricks on AWS上で行いました。処理構築時のDBRは15.4ML、クラスタはg5.xlarge(GPU)を持ちました。
いつものように無駄に複雑 実用性を考えて、Databricks Mosaic AI Model Servingにエンドポイントを構築してからいくつかのクエリを実行してみます。

エンドポイントの構築

今回は推論エンジンにvLLMを使ったカスタムChatModelを定義し、MLflowを使ってモデルをロギング、Model Serving機能でエンドポイントにデプロイするという流れで準備します。

まずはパッケージを更新。

# torch
%pip install -q -U torch==2.4.0 --index-url https://download.pytorch.org/whl/cu121

# Flash Attention
%pip install -q -U flash-attn --no-build-isolation

%pip install -q -U "vllm==0.6.2"

%pip install -q -U "mlflow-skinny[databricks]>=2.16.2"

dbutils.library.restartPython()

MLflowのカスタムChatModelを定義。

%%writefile "./vllm_chat_model.py"

from typing import List, Union, Iterator, Tuple
import uuid
import logging
import argparse

import mlflow
from mlflow.types.llm import ChatResponse, ChatMessage, ChatParams
from mlflow.models import set_model
import json

from vllm import LLM, SamplingParams

logger = logging.getLogger(__name__)


class VLLMCustomChatModel(mlflow.pyfunc.ChatModel):

    def __init__(self):
        self.model_name = ""
        self.model_config = {}
        self.additional_stop_tokens = []
        self.seed = 123

        self._llm = None

    def load_context(self, context):
        """Load the model from the contexts."""

        model_path = context.artifacts["llm-model"]
        self.model_config = context.model_config.get("model", {})

        logger.info(f"Loading model from {model_path}")

        # Model Configから設定を読み込む
        model_name = self.model_config.get("model_name", "No Name")
        gpu_memory_utilization = self.model_config.get("gpu_memory_utilization", 0.8)
        additional_stop_tokens = self.model_config.get("additional_stop_tokens", [])
        tensor_parallel_size = self.model_config.get("tensor_parallel_size", 1)        
        seed = self.model_config.get("seed", 123)

        # Initialize vLLM
        self._llm = LLM(
            model=model_path,
            gpu_memory_utilization=gpu_memory_utilization,
            tensor_parallel_size=tensor_parallel_size,
        )
        self.model_name = model_name
        self.additional_stop_tokens = additional_stop_tokens
        self.seed = seed

    def predict(
        self,
        context,
        messages: List[mlflow.types.llm.ChatMessage],
        params: mlflow.types.llm.ChatParams,
    ):
        """Predict the response to the given messages."""

        conversation = [{"role": m.role, "content": m.content} for m in messages]
        sampling_params = self._build_settings(
            params, self.additional_stop_tokens, self.seed
        )

        outputs = self._llm.chat(
            conversation,
            sampling_params=sampling_params,
            use_tqdm=False,
        )

        return self._build_response(outputs[0], self.model_name)

    @staticmethod
    def _build_settings(params, additional_stop_tokens, seed):
        """Build the settings for the given parameters."""

        # サンプリングの設定
        settings = SamplingParams()
        settings.temperature = params.temperature or 1.0
        settings.top_k = params.top_k or 50
        settings.top_p = params.top_p or 0.8
        settings.frequency_penalty = params.frequency_penalty or 0.0
        settings.presence_penalty = params.presence_penalty or 0.0
        settings.max_tokens = params.max_tokens or 100
        settings.stop = additional_stop_tokens
        settings.seed = seed

        return settings

    @staticmethod
    def _build_response(output, model_name):
        """Build the response for the given parameters."""

        usage = {
            "prompt_tokens": len(output.prompt_token_ids),
            "completion_tokens": len(output.outputs[0].token_ids),
        }
        usage["total_tokens"] = usage["prompt_tokens"] + usage["completion_tokens"]
        finish_reason = output.outputs[0].finish_reason

        response = {
            "id": output.request_id,
            "model": model_name,
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": output.outputs[0].text,
                    },
                    "finish_reason": finish_reason,
                }
            ],
            "usage": usage,
        }

        return ChatResponse(**response)


set_model(VLLMCustomChatModel())

MLflowを使って、モデルをロギング。
Huggingfaceのモデルを事前にダウンロードして使いました。
モデルリポジトリはそれぞれ以下です。

gemma-2-2b-jpn-itについては、モデルフォルダに含まれるconfig.jsonのhidden_activation属性をhidden_actに名前変更する(もしくはhidden_activationをコピーしてhidden_act属性を作る)必要がありました。
今回のvLLMバージョンではそうしないとモデルロードができない模様。

import mlflow
import os

# 各種パラメータ(Baku用)
registered_model_name='training.llm.gemma-2-baku-2b-it'
endpoint_name='gemma-2-baku-2b-it-endpoint'
model_path='/Volumes/training/llm/model_snapshots/models--rinna--gemma-2-baku-2b-it/'
model_name='gemma-2-baku-2b-it'
tensor_parallel_size=1
gpu_memory_utilization=0.9
additional_stop_tokens=[]

# MLflowのロギング設定
mlflow.set_registry_uri("databricks-uc")
extra_pip_requirements = [
    "torch==2.4.0 --index-url https://download.pytorch.org/whl/cu121",
    "https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.3cxx11abiFALSE-cp311-cp311-linux_x86_64.whl", # flash-attn
    "vllm==0.6.2",
]
pip_requirements = mlflow.pyfunc.get_default_pip_requirements() + extra_pip_requirements

artifacts = {
    "llm-model": model_path
}
model_config = {
        "model": {
            "model_name": registered_model_name,
            "additional_stop_tokens": additional_stop_tokens,
            "gpu_memory_utilization": gpu_memory_utilization,
            "seed": 123,
            "tensor_parallel_size": tensor_parallel_size,
    }
}
input_example = {
    "messages": [
        {
            "role": "user",
            "content": "日本で一番高い山はなに?",
        }
    ]
}

with mlflow.start_run() as run:
    _ = mlflow.pyfunc.log_model(
        artifact_path="model",
        python_model="vllm_chat_model.py",
        artifacts=artifacts,
        model_config=model_config,
        pip_requirements=pip_requirements,
        input_example=input_example,
        await_registration_for=1200,  # モデルサイズが大きいので長めの待ち時間にします
        registered_model_name=registered_model_name,  # 登録モデル名 in Unity Catalog
    )

最後にMosaic AI Model Servingにデプロイ。

import requests
import json
from mlflow import MlflowClient

# デプロイ設定
endpoint_workload_type='GPU_MEDIUM'
endpoint_workload_size='Small'
endpoint_scale_to_zero_enabled='true'

API_ROOT = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiUrl().get() 
API_TOKEN = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get()

client=MlflowClient()
versions = [mv.version for mv in client.search_model_versions(f"name='{registered_model_name}'")]

data = {
    "name": endpoint_name,
    "config":{
        "served_entities": [
        {
            "entity_name": registered_model_name,
            "entity_version": versions[0],
            "workload_type": endpoint_workload_type,
            "workload_size": endpoint_workload_size,
            "scale_to_zero_enabled": endpoint_scale_to_zero_enabled,
        }]
    },
}

headers = {"Context-Type": "text/json", "Authorization": f"Bearer {API_TOKEN}"}

response = requests.post(url=f"{API_ROOT}/api/2.0/serving-endpoints", json=data, headers=headers)

print(json.dumps(response.json(), indent=4))

これでエンドポイントができました。
Gemma2 2B JPNとGemma 2 Baku 2Bともモデルのパス等以外は同じ内容でエンドポイントまで作成できます。

推論の実行

それでは作成したエンドポイントを使って推論します。
ノートブックを作成し、必要なパッケージをインストール。

%pip install -U -qq langchain_core langchain-databricks
%pip install -U "mlflow-skinny[databricks]>=2.16.2"

dbutils.library.restartPython()

そして推論を実行。
変数endpoint_nameには、Model Servingエンドポイントの名前が入っています。

import time
from langchain_databricks import ChatDatabricks

prompts = [
    "Hello, what is your name?",
    "Databricksとは何ですか?詳細に教えてください。",
    "まどか☆マギカでは誰が一番かわいい?",
    "ランダムな10個の要素からなるリストを作成してソートするコードをPythonで書いてください。",
    "現在の日本の首相は誰?",
    "あなたはマラソンをしています。今3位の人を抜きました。あなたの今の順位は何位ですか?ステップバイステップで考えてください。",
]

messages = [[("user", p)] for p in prompts]
chat_model = ChatDatabricks(endpoint=endpoint_name)

start_time = time.time()
outputs = chat_model.batch(messages, temperature=0.0, max_tokens=512)
end_time = time.time()

processing_time = end_time - start_time
total_tokens = 0
for p in zip(prompts, outputs):

    total_tokens += p[1].response_metadata["total_tokens"]

    print("Q:", p[0])
    print("A:", p[1].content)
    print("-"*100)

print("processing_time:", processing_time)
print("total_tokens:", total_tokens)
print("avg. tokens per sec.:", total_tokens / processing_time)

それぞれの出力結果は以下の通り。

gemma-2-2B JPN-it

gemma-2-2B JPN-it 出力
Q: Hello, what is your name?
A: Hello! My name is Gemma. 😊  I'm an open-source AI assistant.  


How can I help you today? 
<end_of_turn>
----------------------------------------------------------------------------------------------------
Q: Databricksとは何ですか?詳細に教えてください。
A: ## Databricks:大規模データ分析と機械学習のための統合プラットフォーム

Databricksは、**大規模データ分析と機械学習を容易にする統合プラットフォーム**です。 

**特徴とメリット**

* **統合された開発環境:** データ分析、機械学習、データ可視化、データ pipelineを一体化できる環境です。
* **Apache Sparkベース:** Apache Sparkの強力な処理能力とオープンソースの柔軟性を活かします。
* **クリーンなコードと効率性:** 標準的なデータ分析の言語(Python、Scala)や機械学習のフレームワーク(MLlib、Spark ML)をサポートし、クリーンなコードと効率的な処理を実現します。
* **コラボレーションとスケーラビリティ:** チームメンバーが効率的にデータ分析や機械学習の作業を行うことができ、大規模なデータセットを処理する際のスケーラビリティも実現します。
* **安全なデータ管理:** 厳格なアクセス制御とセキュリティ機能により、機密情報へのアクセス制限を容易に行えます。
* **豊富なツールとサービス:** Databricksには、データ分析、機械学習、データ可視化に必要な様々なツールとサービスが用意されています。

**具体的な活用例**

* **データ分析:** 顧客分析、マーケティング分析、販売予測など、大量のデータから意思決定をサポート
* **機械学習:** 画像認識、自然言語処理、異常検知など、複雑な機械学習モデルの構築と実行
* **データ可視化:** 分析結果を分かりやすく可視化し、意思決定を支援

**Databricksの構成要素**

* **Databricks Workbench:**  開発者向けの統合開発環境で、データ分析、機械学習、データ可視化の作業を効率的に行えます。
* **Databricks Notebook:**  Python、Scalaなどのコードを実行して、データ分析や機械学習の作業を効率的に行えます。
* **Databricks Delta Lake:**  データのクリーンな管理と高速なデータ分析を実現します。
* **Databricks Clusters:**  高性能な計算環境を提供し、大規模なデータセットを処理できます。
* **Databricks Security:**  アクセス制御とセキュリティ機能により、機密情報へのアクセス制限を容易に行えます。



**メリットまとめ**

* **効率性:** データ分析や機械学習の作業を、統合された
----------------------------------------------------------------------------------------------------
Q: まどか☆マギカでは誰が一番かわいい?
A: まどか☆マギカのキャラクターたちの可愛さについて、個人の好みによって意見は異なります。 

**人気キャラクターとして、特に多くのファンから支持されているのは:**

* **まどか:**  主人公であり、物語の核心を担うキャラクター。明るい性格と、ひたむきな姿は多くのファンに愛される。
* **マギカ:**  まどかと共に物語を展開する、個性的なキャラクター。独特の可愛らしさと魅力が人気。
* **あゆみ:**  まどかの友人であり、いつも明るく元気なキャラクター。可愛らしさと強さが魅力。


 

最終的には、誰が一番かわいいかは、あなたの好みによって異なります。それぞれのキャラクターの魅力は、性格、デザイン、ストーリー展開など、様々な要素によって表現されています。




 


<end_of_turn>
----------------------------------------------------------------------------------------------------
Q: ランダムな10個の要素からなるリストを作成してソートするコードをPythonで書いてください。
A: ```python
import random

# ランダムな10個の要素を生成
random_elements = random.sample(range(1, 101), 10)

# リストをソート
random_elements.sort()

# ソートされたリストを表示
print(f"ランダムな要素: {random_elements}") 
```

**解説:**

1. **`import random`**:  Pythonのランダムな要素生成機能をインポートします。
2. **`random.sample(range(1, 101), 10)`**:  1から100までの整数(100個の要素)をランダムに10個選んでリストを作成します。
3. **`random_elements.sort()`**:  リスト `random_elements` を昇順(小さいから大きい)でソートします。
4. **`print(f"ランダムな要素: {random_elements}")`**:  ソートされたリストを表示します。


 


<end_of_turn>
----------------------------------------------------------------------------------------------------
Q: 現在の日本の首相は誰?
A: 現在の日本の首相は **岸田文雄** です。 



<end_of_turn>
----------------------------------------------------------------------------------------------------
Q: あなたはマラソンをしています。今3位の人を抜きました。あなたの今の順位は何位ですか?ステップバイステップで考えてください。
A: 今の順位は3位です。 

 

 


<end_of_turn>
----------------------------------------------------------------------------------------------------
processing_time: 14.166987895965576
total_tokens: 1111
avg. tokens per sec.: 78.42175119782425

gemma-2 BAKU 2B-it

gemma-2 BAKU 2B-it 出力
Q: Hello, what is your name?
A: Hello! I am Gemma, an open-weights AI assistant.  How can I help you today? 😊  

----------------------------------------------------------------------------------------------------
Q: Databricksとは何ですか?詳細に教えてください。
A:  Databricks は、オープンソースの機械学習およびデータサイエンスプラットフォームです。Apache Spark をベースにしており、データウェアハウス、機械学習、データビジュアライゼーションなどのさまざまなデータ処理タスクを統合し、ユースケースに合わせたソリューションを提供します。

簡単に説明すると、Databricksは次のような機能を備えた統合プラットフォームと言えるでしょう。

**主な特徴:**

1. **Sparkベース:** Databricksは、Apache Spark を使用して、大規模なデータセットを高速に処理する並列データ処理フレームワークです。
2. **データ分析、機械学習、可視化:** これらの機能を包括的に提供し、ユーザーはデータのクレンジング、クリーニング、モデリング、分析、そして結果の可視化まで一連のプロセスを実行できます。
3. **オープンソース:** Databricks はオープンソースであり、コミュニティからのフィードバックや開発貢献を受け、自由に使用、拡張、改変することができます。
4. **コンテナ技術:**  Databricks はコンテナ技術である Docker や Kubernetes と統合しており、柔軟なワークフローとスケーラビリティを提供します。
5. **Collaboration & Security:**  ユーザーは Databricks 内で共同作業を行い、セキュリティを確保しながらデータ分析を安全に実行できます。

**主なユースケース:**

* **データ分析:**  大量データの可視化、データの探索、統計分析、パターン検出、機械学習モデルの構築
* **機械学習:**  画像認識、自然言語処理、感情分析、予測分析、時系列分析など、様々なユースケースに対応する機械学習モデルを構築
* **データウェアハウス:**  データのクレンジング、変換、統合、可視化、および管理
* **ビジネスアナリティクス:**  ビジネスの洞察を得るためのリアルタイムデータ分析


**メリット:**

* **迅速な開発と展開:** Databricks は、包括的なツールセットを提供し、開発時間を短縮し、迅速なプロトタイプ作成、テスト、デプロイメントを可能にします。
* **柔軟性:**  コンテナ技術との統合により、さまざまなインフラストラクチャとスケーリングオプションを選択できます。
* **コラボレーション:**  共同作業機能により、チームメンバー間のコミュニケーションを促進し、効率的なデータ分析プロセスを支援します。
* **オープンソース:**  オープンソースである
----------------------------------------------------------------------------------------------------
Q: まどか☆マギカでは誰が一番かわいい?
A: まどか☆マギカという作品において、誰が一番かわいいかは、とても主観的な意見であり、人によって異なると思います。

人気や印象、そして美少女としての要素を総合的に判断すると、多くのファンから支持を得ているのは「マミ・ルイス」でしょう。

彼女は、

* **魔法の才能****強い意志** を持ち、一見クールで強そうな印象ですが、実は内面は優しく、生徒たちを支えようとする情熱的な部分があります。
* **整った容姿****凛とした立ち姿** は、多くのファンを魅了し、その魅力的なルックスは忘れられません。
* **ユーモア****個性的な性格** が、彼女をより魅力的にしています。

これらの要素が複合的に絡み合って、マミは多くのファンから「かわいい」と評されています。

ただし、他の主要キャラクターもそれぞれ魅力があり、可愛さを感じる点があります。

例えば、「美琴」は可愛らしい外見と天真爛漫な性格、そして魔法の能力が魅力的です。また、「杏子」はクールな性格ながらも、実はとても情熱的で心優しい面を見せることもあります。

最終的には、あなたがどのキャラクターを一番「かわいい」と感じるかによって、まどか☆マギカにおけるそれぞれのキャラクターの魅力は大きく変わってくるでしょう。
----------------------------------------------------------------------------------------------------
Q: ランダムな10個の要素からなるリストを作成してソートするコードをPythonで書いてください。
A: ```python
import random

# ランダムな10個の要素
random_list = [
    "apple", "banana", "orange", "grape", "watermelon",
    "pineapple", "mango", "strawberry", "apricot", "kiwi"
]

# リストをランダムにソート
random.shuffle(random_list)

# ソートされたリストを表示
print(random_list)
```

このコードでは、まず `import random` でPythonのランダム関数モジュールをインポートします。

次に、リストを作成するために `random_list` という名前のリスト変数を作成し、さまざまな果物名が入っています。

そして、 `random.shuffle(random_list)` の行で `shuffle()` メソッドを使って `random_list` の要素をランダムに並べ替えます。

最後に、 `print(random_list)` でソートされた `random_list` を表示します。

実行すると、毎回異なる順番の果物が表示されるはずです。


より詳細に説明すると、

* `random.shuffle(list)`: このメソッドは、与えられたリストをランダムに再配置します。 


----------------------------------------------------------------------------------------------------
Q: 現在の日本の首相は誰?
A: 現在の日本の首相は、** Fumio Kishida(岸田文雄)** 氏です。 


2021年10月に発足した第1次岸田内閣が現在、日本の政権を担っています。 

----------------------------------------------------------------------------------------------------
Q: あなたはマラソンをしています。今3位の人を抜きました。あなたの今の順位は何位ですか?ステップバイステップで考えてください。
A: マラソンをしていると、3位の人を抜いたとのこと、素晴らしいですね!

あなたの現在の順位は、**3位**です。

以下、ステップバイステップで説明します。

1. あなたは、スタート時に何位でしたか?
2. 他のランナーと競争しながら走っていました。
3. 3位の人を追い越したため、あなたは前のランナーとの間を縮め、現在3位にいます。

したがって、現在あなたが走っている状況から、あなたの順位は3位であることがわかります。
----------------------------------------------------------------------------------------------------
processing_time: 18.212523460388184
total_tokens: 1387
avg. tokens per sec.: 76.15638782932486

gemma-2-2B JPN-itは追加のストップトークン設定が必要だったかな。

gemma-2-2B JPN-itは少し端的な回答の傾向があり、gemma-2 BAKU 2B-itの方は少し柔らかい日本語表現で回答がくる傾向にあるようです。
また、マラソン問題を解けているところから、どちらも推論性能高そうですね。
わずか2Bのパラメータなのに。。。すごい。。。

まとめ

最近出たGemma-2 2Bモデルの日本語ファインチューニング/継続事前学習モデルを試してみました。
単純に試すだけだったらModel Servingでエンドポイント作る必要はないのですが、一度エンドポイント作っておくと再利用しやすく(不要であれば停止しておける)便利です。

しかし、ローカルLLMはスケーリング則に則ってどんどん大型化するかと思いきや、小型高性能なモデル(いわゆるSLM)も注目すべき領域となってきていますね。
トップ性能で言えば巨大なモデルの方が優位ではありますが、複合AIシステムの構築で性能を出すことも考えるとこういった優秀な小型モデルの方が全体的な利便性は高いんだろうなと思います。

まだまだ面白い状況が続くなあ。

2
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?