導入
少し時間が経ちましたが、Google社が第2世代のGemma、Gemma2をリリースしました。
9Bと27Bという2種のパラメータサイズのモデルが公開されています。
ベンチマークでも高い性能を示しています。
というか、9Bがサイズの割にいい性能を示してますね。
というわけで、気になったので軽く試してみました。今回は指示チューニングされた9Bモデルを利用します。
検証はDatabricks on AWS上で実施しました。
DBRは15.3ML、クラスタタイプはg5.xlargeです。
推論エンジンにはExLlamaV2を利用します。
量子化モデルを利用していますので、本来のモデルとは出力結果が異なることに注意ください。
Step1. パッケージインストール
Flash-AttentionとExLlamaV2をインストール。
%pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.1/flash_attn-2.6.1+cu123torch2.3cxx11abiFALSE-cp311-cp311-linux_x86_64.whl
%pip install https://github.com/turboderp/exllamav2/releases/download/v0.1.7/exllamav2-0.1.7+cu121.torch2.3.1-cp311-cp311-linux_x86_64.whl
dbutils.library.restartPython()
Step2. モデルのロード
以下のモデルを事前にダウンロードしておき、そこからロードします。
from exllamav2 import (
ExLlamaV2,
ExLlamaV2Config,
ExLlamaV2Cache_Q4,
ExLlamaV2Tokenizer,
)
from exllamav2.generator import ExLlamaV2BaseGenerator, ExLlamaV2Sampler
import time
from exllamav2.generator import (
ExLlamaV2DynamicGenerator,
ExLlamaV2DynamicJob,
ExLlamaV2Sampler,
)
batch_size = 1
cache_max_seq_len = 8192
model_directory = "/Volumes/training/llm/model_snapshots/models--turboderp--gemma-2-9b-it-exl2--8.0bpw/"
config = ExLlamaV2Config(model_directory)
config.arch_compat_overrides()
model = ExLlamaV2(config)
print("Loading model: " + model_directory)
cache = ExLlamaV2Cache_Q4(
model,
lazy=True,
batch_size=batch_size,
max_seq_len=cache_max_seq_len,
)
model.load_autosplit(cache)
tokenizer = ExLlamaV2Tokenizer(config)
generator = ExLlamaV2DynamicGenerator(
model=model,
cache=cache,
tokenizer=tokenizer,
max_batch_size=1024,
max_q_size=1,
)
gen_settings = ExLlamaV2Sampler.Settings(
token_repetition_penalty=1.1,
temperature=0.1,
top_k=0,
top_p=0.6,
)
max_new_tokens = 512
Step3. バッチ推論
こちらの記事と同様の問をなげて回答を取得します。
# 今回推論する内容
prompts = [
"Hello, what is your name?",
"Databricksとは何ですか?詳細に教えてください。",
"まどか☆マギカでは誰が一番かわいい?",
"ランダムな10個の要素からなるリストを作成してソートするコードをPythonで書いてください。",
"現在の日本の首相は誰?",
"あなたはマラソンをしています。今3位の人を抜きました。あなたの今の順位は何位ですか?ステップバイステップで考えてください。",
]
system_prompt = """あなたは親切なAIアシスタントです。"""
# プロンプトを整形
def format_prompt(sp, p):
return (
"<start_of_turn>user\n"
f"{sp}<end_of_turn>\n"
"<start_of_turn>model\nOK!<end_of_turn>\n"
"<start_of_turn>user\n"
f"{p}<end_of_turn>\n"
"<start_of_turn>model\n"
)
print()
print("Creating jobs...")
completions = []
f_prompts = [format_prompt(system_prompt, p) for p in prompts]
for idx, p in enumerate(prompts):
f_prompt = format_prompt(system_prompt, p)
completions.append(f"Q: {p}\n")
prompt_ids = tokenizer.encode(
f_prompt,
encode_special_tokens=True,
add_bos=True,
)
job = ExLlamaV2DynamicJob(
input_ids=prompt_ids,
gen_settings=gen_settings,
max_new_tokens=max_new_tokens,
identifier=idx,
stop_conditions = [tokenizer.eos_token_id],
)
generator.enqueue(job)
# Generate
print()
print("Generating...")
num_completions = 0
num_tokens = 0
time_begin = time.time()
while generator.num_remaining_jobs():
results = generator.iterate()
bsz = len(set([r["identifier"] for r in results]))
for result in results:
if not result["eos"]: continue
idx = result["identifier"]
response = result["full_completion"]
completions[idx] += f"A: {response.lstrip()}"
# パフォーマンス計測
num_completions += 1
num_tokens += result["new_tokens"]
elapsed_time = time.time() - time_begin
rpm = num_completions / (elapsed_time / 60)
tps = num_tokens / elapsed_time
print()
print("---------------------------------------------------------------------------")
print(f"Current batch size: {bsz}")
print(f"Avg. completions/minute: {rpm:.2f}")
print(f"Avg. output tokens/second: {tps:.2f}")
print("---------------------------------------------------------------------------")
# 推論結果出力
print()
print(f"Completion {idx}:")
print()
print(completions[idx])
以下、回答結果の出力です。(処理の関係上、質問リストの順番通りになっていません)
Creating jobs...
Generating...
---------------------------------------------------------------------------
Current batch size: 6
Avg. completions/minute: 34.22
Avg. output tokens/second: 18.82
---------------------------------------------------------------------------
Completion 0:
Q: Hello, what is your name?
A: I am Gemma, a large language model. I don't have a personal name like humans do. 😊
How can I help you today?<end_of_turn>
---------------------------------------------------------------------------
Current batch size: 5
Avg. completions/minute: 57.60
Avg. output tokens/second: 36.48
---------------------------------------------------------------------------
Completion 4:
Q: 現在の日本の首相は誰?
A: 現在の日本の首相は **岸田文雄** です。 2021年10月4日から在任しています。
何か他に質問があれば、お気軽にどうぞ!😊
<end_of_turn>
---------------------------------------------------------------------------
Current batch size: 4
Avg. completions/minute: 83.06
Avg. output tokens/second: 56.30
---------------------------------------------------------------------------
Completion 5:
Q: あなたはマラソンをしています。今3位の人を抜きました。あなたの今の順位は何位ですか?ステップバイステップで考えてください。
A: 1. あなたは3位の人を追い越したので、その人の前の順位になります。
2. 3位の人の前にいるのは2位です。
答え: 今度は2位です! 🎉
<end_of_turn>
---------------------------------------------------------------------------
Current batch size: 3
Avg. completions/minute: 39.39
Avg. output tokens/second: 44.80
---------------------------------------------------------------------------
Completion 2:
Q: まどか☆マギカでは誰が一番かわいい?
A: 「一番かわいい」は人それぞれ好みが違いますよね!
まどか☆マギカには、それぞれの魅力を持つ可愛い女の子がたくさんいますね。
* **鹿目まどか**: 明るく元気で優しい性格の主人公。笑顔がとってもキュート!
* **美樹さやか**: クールビューティーで少しツンデレなところがたまらない!
* **暁美ほむら**: 強い意志と謎めいた雰囲気を持つ、ちょっと大人っぽい魅力があります。
* **巴マミ**: 美しい容姿と優雅な立ち振る舞いが印象的です。
どのキャラクターが一番かわいいかは、あなた次第です! ぜひ、ご自身の好きなポイントを見つけてみてくださいね😊✨
<end_of_turn>
---------------------------------------------------------------------------
Current batch size: 2
Avg. completions/minute: 20.57
Avg. output tokens/second: 41.75
---------------------------------------------------------------------------
Completion 3:
Q: ランダムな10個の要素からなるリストを作成してソートするコードをPythonで書いてください。
A: ```python
import random
# ランダムな10個の要素からなるリストを作成
my_list = random.sample(range(1, 51), 10) # 1から50までの数字の中から10個をランダムに選択
print("元のリスト:", my_list)
# リストを昇順にソート
my_list.sort()
print("ソートされたリスト:", my_list)
```
**解説:**
1. **`import random`**: `random`モジュールをインポートし、乱数生成機能を使用できるようにします。
2. **`my_list = random.sample(range(1, 51), 10)`**:
- `range(1, 51)`: 1から50までの整数の範囲を生成します。
- `random.sample(..., 10)`: この範囲から重複なく10個の要素をランダムに選び、`my_list`というリストに格納します。
3. **`print("元のリスト:", my_list)`**: 元のリストを出力します。
4. **`my_list.sort()`**: `my_list`内の要素を昇順にソートします。この操作はリスト自体を変更します。
5. **`print("ソートされたリスト:", my_list)`**: ソートされたリストを出力します。
このコードを実行すると、ランダムな10個の数字を含むリストが生成され、その後昇順にソートされます。<end_of_turn>
---------------------------------------------------------------------------
Current batch size: 1
Avg. completions/minute: 18.10
Avg. output tokens/second: 56.37
---------------------------------------------------------------------------
Completion 1:
Q: Databricksとは何ですか?詳細に教えてください。
A: ## Databricks: データ分析と機械学習のためのプラットフォーム
Databricksは、データエンジニア、データサイエンティスト、機械学習エンジニアが協力してデータを処理し、洞察を導き出し、モデルを構築するためのクラウドベースの統合開発環境(IDE)を提供するプラットフォームです。
**主な特徴:**
* **Apache Sparkに基づく:** Databricksは、大規模データ処理に特化したオープンソースのフレームワークであるApache Sparkを基盤としています。Sparkは高速でスケーラブルであり、様々なデータ処理タスクに対応できます。
* **Unified Platform:** Databricksは、データエンジニアリング、データ分析、機械学習を一つのプラットフォーム上で統合しています。これにより、チームメンバーは共通のツールと環境を使用でき、コラボレーションが容易になります。
* **Collaborative Workspace:** Databricksは、チームが共同でコードを記述、実行、デバッグできるインタラクティブなノートブック環境を提供します。この機能は、データ分析や機械学習プロジェクトにおけるコミュニケーションと効率性を向上させます。
* **Machine Learning Capabilities:** Databricksには、機械学習モデルのトレーニング、評価、デプロイに必要なツールが組み込まれています。AutoML機能も提供されており、初心者でも簡単に機械学習モデルを作成することができます。
* **Scalability and Reliability:** Databricksは、クラウド上のスケーラブルなインフラストラクチャを利用しており、大規模なデータセットにも対応できます。また、高可用性とデータ保護機能も備えています。
* **Integration with Other Tools:** Databricksは、他のデータ分析ツールやサービスとの連携も可能です。例えば、Amazon S3、Azure Blob Storage、Google Cloud Storageなどのストレージサービス、Tableau、Power BIなどのビジュアライゼーションツールなどとの連携が可能です。
**Databricksの利点:**
* **迅速な開発サイクル:** Sparkの高速処理能力と統合されたツールにより、データ分析と機械学習の開発サイクルが短縮されます。
* **高い生産性:** コラボレーション機能と自動化機能により、チームの生産性が向上します。
* **コスト削減:** クラウドベースのプラットフォームを採用することで、ハードウェア投資や管理コストを削減できます。
* **柔軟性:** Databricksは、さまざまなデータソースやユースケースに対応できる柔軟性の高いプラットフォームです。
**まとめ:**
Databricksは、データ
<end_of_turn>
文字列が最後に生成されてしまっていますが、回答自体は非常に適正な内容が返ってきていると思います。絵文字が多いのはGemmaのクセっぽい感じですね。プロンプトの与え方で変わると思いますが。
このサイズでの性能としては非常に高いんじゃないかと思いました。RAGと組み合わせた場合の能力も確認してみたいと思います。
まとめ
Google社のGemma2 9B-it(EXL2量子化版)を試してみました。
一応商用利用可能なライセンスですし、パラメータサイズと性能を見ても使いやすいモデルだと思います。
日本語用途としては、これやCALM-3の利用が最近だと使いやすそうですね。