LoginSignup
0
0

N番煎じでMistral.aiのMistral-7B-Instruct-v0.3を試す on Databricks Model Serving

Posted at

新しいLLMは連続で出る、これテストに出ます。

導入

Mistral.aiより、Mistral-7Bの最新モデルv0.3が公開されました。

v0.2と比べて、以下のような変化点があります。

  • 語彙数を32768に拡張
  • v3 トークナイザーをサポート
  • function callingをサポート

特に日本語性能がどのように変わったのか確認したく、Databricks上で試してみました。

検証環境はDatabricks on AWS、DBRは15.1MLです。
今回はプロビジョニングされたスループットの基盤モデルとしてモデルサービングに登録して確認してみます。

Step1. パッケージインストール

MLflowとtransformersの最新化、またlangchainをインストールしておきます。

%pip install -U "mlflow>=2.13.0" transformers accelerate langchain langchain_community

dbutils.library.restartPython()

Step2. モデルのロード

HuggingFace上から事前に以下のモデルをダウンロードしておき、そちらをロードします。

ダウンロードするためには、Huggingface上で連絡先情報を共有することに同意する必要があります。

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_path = "/Volumes/training/llm/model_snapshots/models--mistralai--Mistral-7B-Instruct-v0.3/"

model = AutoModelForCausalLM.from_pretrained(
    model_path, 
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(
    model_path,
)

Step3. MLFlowへのモデル登録

MLFlowのモデルをUnity Catalog下に登録します。

import mlflow
import numpy as np

# Models in UCを使っていない場合には下の行をコメントアウトします 
# そして、3レベル名前空間の名称ではなくシンプルにモデル名を指定します
mlflow.set_registry_uri('databricks-uc')
CATALOG = "training"
SCHEMA = "llm"
registered_model_name = f"{CATALOG}.{SCHEMA}.mistral-7b-instruct-v03"

# サンプル入力の定義
input_example = {
  "messages": [
    {
      "role": "user",
      "content": "Hello!"
    }
  ],
  "max_tokens": 10,
  "stream": False
}

# 新規MLflowランをスタート
with mlflow.start_run():
    components = {
        "model": model,
        "tokenizer": tokenizer,
    }
    mlflow.transformers.log_model(
        transformers_model=components,
        task="llm/v1/chat",
        artifact_path="model",
        registered_model_name=registered_model_name,
        input_example=input_example,
        example_no_conversion=True,
    )

Step4. モデルサービングのエンドポイント登録

Step3で登録したモデルを使って、Databricksのモデルサービングへエンドポイントを作成します。

まず、「プロビジョニングされたスループットの基盤モデル」として利用可能かどうか確認。

import requests
import json
from mlflow import MlflowClient

# 登録されたMLflowモデルの名前
CATALOG = "training"
SCHEMA = "llm"
registered_model_name = f"{CATALOG}.{SCHEMA}.mistral-7b-instruct-v03"

model_name = registered_model_name

# MLflowモデルの最新バージョンを取得
client=MlflowClient()
versions = [mv.version for mv in client.search_model_versions(f"name='{registered_model_name}'")]
model_version = versions[0]

# 現在のノートブックのコンテキストからAPIエンドポイントとトークンを取得
API_ROOT = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiUrl().get() 
API_TOKEN = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get()

headers = {"Context-Type": "text/json", "Authorization": f"Bearer {API_TOKEN}"}

response = requests.get(url=f"{API_ROOT}/api/2.0/serving-endpoints/get-model-optimization-info/{model_name}/{model_version}", headers=headers)

print(json.dumps(response.json(), indent=4))
出力
{
    "optimizable": true,
    "model_type": "mistral",
    "throughput_chunk_size": 970,
    "dbus": 24
}

optimizableがtrueですので、v0.3でもちゃんと対応していますね。

では、エンドポイントへ登録。

endpoint_name = "mistral-7b-instruct-v03-endpoint"

# プロビジョンされるスループットの最小値を指定 
min_provisioned_throughput = response.json()['throughput_chunk_size']

data = {
    "name": endpoint_name,
    "config": {
        "served_entities": [
            {
                "entity_name": model_name,
                "entity_version": model_version,
                "min_provisioned_throughput": 0,
                "max_provisioned_throughput": min_provisioned_throughput,
                "scale_to_zero_enabled": True,
            }
        ]
    },
}

headers = {"Context-Type": "text/json", "Authorization": f"Bearer {API_TOKEN}"}

response = requests.post(url=f"{API_ROOT}/api/2.0/serving-endpoints", json=data, headers=headers)

問題なければ、10分程度でサービングの準備が完了します。

image.png

Step5. 試してみる

準備ができたので、langchainを使っていくつかの指示をバッチ実行してみます。

from langchain_community.chat_models import ChatDatabricks
from langchain_core.messages import HumanMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser

chat = ChatDatabricks(endpoint=endpoint_name, temperature=0.1)

prompts = [
    "Hello, what is your name?",
    "Databricksとは何ですか?詳細に教えてください。",
    "まどか☆マギカでは誰が一番かわいい?",
    "ランダムな10個の要素からなるリストを作成してソートするコードをPythonで書いてください。",
    "現在の日本の首相は誰?",
    "あなたはマラソンをしています。今3位の人を抜きました。あなたの今の順位は何位ですか?",
]
d_prompts = [{"user_input": prompt} for prompt in prompts]

chat_template = ChatPromptTemplate.from_messages(
    [
        ("human", "{user_input}"),
    ]
)

chain = chat_template | chat | StrOutputParser()

results = chain.batch(d_prompts, config={"max_concurrency": 16})

# 結果出力
for q, a in zip(d_prompts, results):
    print("---------------------------------------")
    print("Q: " + q["user_input"])
    print("A: " + a.strip())
出力
---------------------------------------
Q: Hello, what is your name?
A: Hello! I don't have a personal name, but you can call me Assistant. How can I assist you today?
---------------------------------------
Q: Databricksとは何ですか?詳細に教えてください。
A: Databricksは、Apache Sparkのベースとなる大規模なデータ処理エンジンを使用して、データエンジニアリング、マシン学習、ビジネスインテリジェンス、およびデータサイエンスのワークロードを実行するクラウドベースのプラットフォームです。 Databricksは、Amazon Web Services (AWS)、Microsoft Azure、およびGoogle Cloud Platform (GCP)上で利用可能です。

Databricksは、Apache Sparkを使用して、データをストリーミング、処理、分析、およびビジュアライズできます。 Sparkは、大規模なデータセットを処理するために使用される、高速、スケーラブルなエンジンです。 Databricksは、Sparkの高速化、管理、および操作を容易にするために、独自のツールと機能を提供します。

Databricksの主な特徴は、次のとおりです。

1. 自動スケーリング:Databricksは、自動的にクラスターをスケーリングし、ワークロードに応じてリソースを割り当てます。
2. セキュリティ:Databricksは、データの保護を強化するために、役割ベースのアクセス制御、暗号化、およびデータの非表示化をサポートします。
3. コラボレーション:Databricksは、複数のユーザーが同時にワークスペースにアクセスでき、コードを共有し、リポジトリを管理できます。
4. 統合:Databricksは、他のクラウドサービスとの統合をサポートし、データを移動、変換、および分析するために使用できます。
5. ビジュアライズ:Databricksは、データをビジュアライズするために、独自のビジュアライズツールを提供します。

Databricksは、データエンジニアリング、データサイエンス、およびビジネスインテリジェンスの各領域で、大規模なデータセットを処理するために使用される重要なツールです。
---------------------------------------
Q: まどか☆マギカでは誰が一番かわいい?
A: マドカ☆マギカのキャラクターについて、かわいいと感じる人は異なることが多いです。しかし、多くの人にとって、エドワード・エルドラード(エド)はかわいいキャラクターと感じられます。エドは、幼い時から、魔法使いとしての負荷に負っていることから、かつ、幼い体で大人の責任を負っていることから、かわいいイメージを持っています。また、彼の優しい、温かい、堅絶な性格も、かわいいと感じられます。しかし、個人的には、かわいいキャラクターと感じる人は異なることが多いです。
---------------------------------------
Q: ランダムな10個の要素からなるリストを作成してソートするコードをPythonで書いてください。
A: 以下は、Pythonでランダムな10個の整数を含むリストを作成し、それをソートして表示するコードです。

```python
import random

# ランダムな10個の整数を含むリストを作成
random_list = [random.randint(1, 100) for _ in range(10)]
print("Random list before sorting:", random_list)

# リストをソート
sorted_list = sorted(random_list)
print("Sorted list:", sorted_list)
```

このコードは、`random.randint(1, 100)` 関数を使用して、1から100の整数をランダムに10個選び、それらをリストに格納します。その後、`sorted()` 関数を使用して、リストを昇順にソートします。最後に、ソート前とソート後のリストを表示します。
---------------------------------------
Q: 現在の日本の首相は誰?
A: 現在、日本の首相(大臣)は徳島県出身の堀内文雄さまです。2021年10月4日に、前首相菅田経済大臣に任命され、10月4日から10月8日に、国会議会での議決を待っています。
---------------------------------------
Q: あなたはマラソンをしています。今3位の人を抜きました。あなたの今の順位は何位ですか?
A: あなたは3位の人を抜きましたので、あなたは4位になります。

v0.2までと比べて、特別な指定も無しに綺麗な日本語で回答が生成されました。
トークナイザの改善のためなのか、学習データの恩恵なのか、かなり改善されています。

日本に関する知識は少ないようで、明らかにハルシネーションを起こしていますし、マラソン問題を見るに論理思考も少し気になるところではあるのですが、7Bというサイズを考えてるとかなり優秀です。
試していませんが、日本語ファインチューニングしなくてもRAGの生成部に使うLLMとしては有効な気がします。

まとめ

Phi-3もそうでしたが、小型で優秀なマルチリンガルモデルが一気に出てきていますね。
性能のよい大型モデルを使うのか、取り回しの良い小型モデルを使うのか、選択肢が増えてきていて嬉しい悩みです。
プロプライエタリモデルも含めると、どのシチュエーションで何を使うのかというノウハウが非常に大事になってきている気がしています。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0