導入
ELYZA社がLlama3をベースとした700億パラメータの「Llama-3-ELYZA-JP-70B」と80億パラメータの「Llama-3-ELYZA-JP-8B」を開発し、性能を公開しました。
ベンチマークの詳細などはこちらの方が詳しいです。
70Bパラメータのモデルは旧GPT-4よりも日本語性能が高いということで、かなり優秀なモデルであろうことが伺えます。(実際、デモサイトで触ってみている分には非常によさそうです)
ただ、70Bパラメータモデルは一般公開されていないため、今回は重みが公開されている8BモデルをDatabricks Model Serving上にデプロイした上で試してみます。
Databricks環境はDatabricks on AWS、DBRは15.3MLです。
Step1. パッケージインストール
MLflowとtransformersの最新化、またlangchainをインストールしておきます。
%pip install -U "mlflow>=2.14.1" transformers accelerate tf-keras langchain langchain_community
dbutils.library.restartPython()
Step2. モデルのロード
HuggingFace上から事前に以下のモデルをダウンロードしておき、そちらをロードします。
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
model_path = "/Volumes/training/llm/model_snapshots/models--elyza--Llama-3-ELYZA-JP-8B/"
model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=torch.bfloat16,
device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(
model_path,
)
Step3. MLFlowへのモデル登録
MLFlowのモデルをUnity Catalog下に登録します。
import mlflow
import numpy as np
# Models in UCを使っていない場合には下の行をコメントアウトします
# そして、3レベル名前空間の名称ではなくシンプルにモデル名を指定します
mlflow.set_registry_uri('databricks-uc')
CATALOG = "training"
SCHEMA = "llm"
registered_model_name = f"{CATALOG}.{SCHEMA}.Llama-3-ELYZA-JP-8B"
# サンプル入力の定義
input_example = {
"messages": [
{
"role": "user",
"content": "Hello!"
}
],
"max_tokens": 10,
"stream": False
}
# 新規MLflowランをスタート
with mlflow.start_run():
components = {
"model": model,
"tokenizer": tokenizer,
}
mlflow.transformers.log_model(
transformers_model=components,
task="llm/v1/chat",
artifact_path="model",
registered_model_name=registered_model_name,
input_example=input_example,
example_no_conversion=True,
)
Step4. モデルサービングのエンドポイント登録
Step3で登録したモデルを使って、Databricksのモデルサービングへエンドポイントを作成します。
まず、プロビジョニングされたモデルAPIへの対応状況およびスループットの最小/最大値を取得。
import requests
import json
from mlflow import MlflowClient
# 登録されたMLflowモデルの名前
CATALOG = "training"
SCHEMA = "llm"
registered_model_name = f"{CATALOG}.{SCHEMA}.Llama-3-ELYZA-JP-8B"
model_name = registered_model_name
# MLflowモデルの最新バージョンを取得
client=MlflowClient()
versions = [mv.version for mv in client.search_model_versions(f"name='{registered_model_name}'")]
model_version = versions[0]
# 現在のノートブックのコンテキストからAPIエンドポイントとトークンを取得
API_ROOT = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiUrl().get()
API_TOKEN = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get()
headers = {"Context-Type": "text/json", "Authorization": f"Bearer {API_TOKEN}"}
response = requests.get(url=f"{API_ROOT}/api/2.0/serving-endpoints/get-model-optimization-info/{model_name}/{model_version}", headers=headers)
print(json.dumps(response.json(), indent=4))
取得した結果を用いて、モデルサービングエンドポイントを作成。
endpoint_name = "Llama-3-ELYZA-JP-8B-endpoint"
# プロビジョンされるスループットの最小値を指定
min_provisioned_throughput = response.json()['throughput_chunk_size']
data = {
"name": endpoint_name,
"config": {
"served_entities": [
{
"entity_name": model_name,
"entity_version": model_version,
"min_provisioned_throughput": 0,
"max_provisioned_throughput": min_provisioned_throughput,
"scale_to_zero_enabled": True,
}
]
},
}
headers = {"Context-Type": "text/json", "Authorization": f"Bearer {API_TOKEN}"}
response = requests.post(url=f"{API_ROOT}/api/2.0/serving-endpoints", json=data, headers=headers)
ここまでで準備完了です。
Step5. 試してみる
準備ができたので、langchainを使っていくつかの指示をバッチ実行してみます。
from langchain_community.chat_models import ChatDatabricks
from langchain_core.messages import HumanMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
chat = ChatDatabricks(endpoint=endpoint_name, temperature=0.1)
prompts = [
"Hello, what is your name?",
"Databricksとは何ですか?詳細に教えてください。",
"まどか☆マギカでは誰が一番かわいい?",
"ランダムな10個の要素からなるリストを作成してソートするコードをPythonで書いてください。",
"現在の日本の首相は誰?",
"あなたはマラソンをしています。今3位の人を抜きました。あなたの今の順位は何位ですか?ステップバイステップで考えてください。",
]
d_prompts = [{"user_input": prompt} for prompt in prompts]
chat_template = ChatPromptTemplate.from_messages(
[
("human", "{user_input}"),
]
)
chain = chat_template | chat | StrOutputParser()
results = chain.batch(d_prompts, config={"max_concurrency": 16})
# 結果出力
for q, a in zip(d_prompts, results):
print("---------------------------------------")
print("Q: " + q["user_input"])
print("A: " + a.strip())
---------------------------------------
Q: Hello, what is your name?
A: Hello! I'm LLaMA, a large language model trained by a team of researcher at Meta AI.
---------------------------------------
Q: Databricksとは何ですか?詳細に教えてください。
A: Databricksは、Apache Sparkを基盤とした、データエンジニアリング、データサイエンス、データ分析のためのクラウドベースのプラットフォームです。2013年に設立されたDatabricksは、2017年にMicrosoftに買収され、現在はMicrosoftの傘下で運営されています。
Databricksの主な特徴は以下の通りです。
1. Apache Sparkを基盤としている: Databricksは、Apache Sparkを基盤としており、Sparkの強力な処理能力とスケーラビリティを活用することができます。
2. クラウドベース: Databricksは、クラウドベースのプラットフォームであり、ユーザーは、AWS、Azure、GCPなどの主要なクラウドプロバイダー上で、Databricksを使用することができます。
3. Data Engineering、Data Science、Data Analysisのワークロードをサポート: Databricksは、データエンジニアリング、データサイエンス、データ分析のためのワークロードをサポートしています。具体的には、データのインジェスト、データの変換、データのストレージ、データの分析、機械学習モデルのトレーニングなどをサポートしています。
4. SQL、Python、R、Scalaなど多くの言語をサポート: Databricksは、SQL、Python、R、Scalaなど多くの言語をサポートしています。ユーザーは、好みの言語を使用して、Databricks上でデータの処理や分析を行うことができます。
5. Collaborationのための機能を提供: Databricksは、Collaborationのための機能を提供しています。具体的には、ワークブック、ノートブック、ジョブの共有、監査ログの提供などを通じて、チームでのデータ分析や機械学習のプロジェクトをサポートしています。
Databricksの主な用途は以下の通りです。
1. Data Engineering: Databricksは、データエンジニアリングのワークロードをサポートしています。具体的には、データのインジェスト、データの変換、データのストレージ、データのパイプラインの管理などをサポートしています。
2. Data Science: Databricksは、データサイエンスのワークロードをサポートしています。具体的には、データの分析、機械学習モデルのトレーニング、データの可視化などをサポートしています。
3. Data Analysis: Databricksは、データ分析のワークロードをサポートしています。具体的には、データの分析、レポートの作成、ダッシュボードの作成などをサポートしています。
Databricksの料金は、使用したリソースやストレージの量に応じて課金されます。具体的には、ノードの数、ストレージの容量、データの転送量などに応じて料金が発生します。
---------------------------------------
Q: まどか☆マギカでは誰が一番かわいい?
A: まどか☆マギカは人気のあるアニメなので、かわいいキャラクターが多くいますが、誰が一番かわいいかは人によって意見が分かれると思います。
---------------------------------------
Q: ランダムな10個の要素からなるリストを作成してソートするコードをPythonで書いてください。
A: 以下は、ランダムな10個の要素からなるリストを作成し、ソートするPythonのコードです。
```python
import random
# ランダムな10個の要素からなるリストを作成
random_list = [random.randint(0, 100) for _ in range(10)]
print("元のリスト:", random_list)
# ソート
random_list.sort()
print("ソート後:", random_list)
```
このコードでは、`random.randint(0, 100)`を使用して、0から100の間のランダムな整数を10個生成し、リストに格納しています。次に、`sort()`メソッドを使用して、リストをソートしています。
---------------------------------------
Q: 現在の日本の首相は誰?
A: 2023年現在の日本の首相は、岸田文雄です。
---------------------------------------
Q: あなたはマラソンをしています。今3位の人を抜きました。あなたの今の順位は何位ですか?ステップバイステップで考えてください。
A: 現在の順位を考えるために、ステップバイステップで考えます。
1. 現在の順位は3位です。
2. 3位の人の後ろにいる人を抜きました。
3. つまり、3位の人の順位が1つ下がります。
4. 3位の人の順位が下がったため、現在の順位は3位から1つ上がります。
5. したがって、現在の順位は4位から1つ上がり、2位です。
結論として、現在の順位は2位です。
まとめ
Llama-3-ELYZA-JP-8Bを試してみました。
さすがに70Bパラメータのモデルに比べると物足りない部分はありますが、軽量の日本語モデルとしては優秀な性能です。
このサイズでも十分な日本語性能を発揮して利用できるというのは、技術の発展を感じますね。