LoginSignup
1
0

導入

先日、カラクリ社からSwallow-MX-8x7b-NVE-v0.1を指示データでファインチューニングしたkarakuri-lm-8x7b-instruct-v0.1が公開されました。

なお、以下でデモを触ることができます。

詳細は下記サイトで解説されています。

特長を言い表す部分を抜粋すると、

「日本語に強く」「ビジネス実装に最適な学習」を最優先に取り組んだことで、国産モデルの中でいち早くFunction callingとRAGに対応したモデルの開発に成功いたしました。

というわけで、RAGやFunction Callingなど、これからのビジネス実装において不可欠となってきた要素を実施していることとなります。

Karakuri社のモデルは以前の記事で触ってみたときから非常に優秀だと考えており、今回も試してみます。

試用にあたって、EXL2で3.5bpwに量子化したものを作成しました。
Databricks上でのEXL2量子化については、手前みそですが以下を参考にしてください。

検証はDatabricks on AWS上で実施しました。
DBRは15.2ML、クラスタタイプはg5.xlargeです。
推論エンジンにはExLlamaV2を利用します。

量子化モデルを利用していますので、本来のモデルとは出力結果が異なることに注意ください。

Step1. パッケージインストール

ExLlamaV2を動かすために必要なパッケージをインストール。

%pip install -U torch==2.3.0
%pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.9.post1/flash_attn-2.5.9.post1+cu122torch2.3cxx11abiFALSE-cp311-cp311-linux_x86_64.whl
%pip install https://github.com/turboderp/exllamav2/releases/download/v0.1.5/exllamav2-0.1.5+cu121.torch2.3.1-cp311-cp311-linux_x86_64.whl

dbutils.library.restartPython()

Step2. モデルのロード

EXL2で変換しておいたモデルをロードします。
(/Volumes/training/llm/model_snapshots/models--karakuri-ai--karakuri-lm-8x7b-instruct-v0.1-exl2--3.5bpw/に保管しています)

from exllamav2 import (
    ExLlamaV2,
    ExLlamaV2Config,
    ExLlamaV2Cache_Q4,
    ExLlamaV2Tokenizer,
)

from exllamav2.generator import ExLlamaV2BaseGenerator, ExLlamaV2Sampler

batch_size = 1
cache_max_seq_len = 4096

model_directory = "/Volumes/training/llm/model_snapshots/models--karakuri-ai--karakuri-lm-8x7b-instruct-v0.1-exl2--3.5bpw/"

config = ExLlamaV2Config(model_directory)

model = ExLlamaV2(config)
print("Loading model: " + model_directory)

cache = ExLlamaV2Cache_Q4(
    model,
    lazy=True,
    batch_size=batch_size,
    max_seq_len=cache_max_seq_len,
) 
model.load_autosplit(cache)
tokenizer = ExLlamaV2Tokenizer(config)

Step3. バッチ推論

本来であればRAGやFunction Callingを試すべきですが、今回は単純な問い合わせの回答を見てみます。
なお、Karakuri社のモデルでは回答の属性をプロンプト内で設定できるのですが、今回は全てデフォルト値相当で実行します。

import time
from exllamav2.generator import (
    ExLlamaV2DynamicGenerator,
    ExLlamaV2DynamicJob,
    ExLlamaV2Sampler,
)

generator = ExLlamaV2DynamicGenerator(
    model=model,
    cache=cache,
    tokenizer=tokenizer,
    max_batch_size=1024,
    max_q_size=1,
)

gen_settings = ExLlamaV2Sampler.Settings(
    token_repetition_penalty=1.1,
    temperature=0.1,
    top_k=0,
    top_p=0.6,
)
max_new_tokens = 512

# 今回推論する内容
prompts = [
    "Hello, what is your name?",
    "Databricksとは何ですか?詳細に教えてください。",
    "まどか☆マギカでは誰が一番かわいい?",
    "ランダムな10個の要素からなるリストを作成してソートするコードをPythonで書いてください。",
    "現在の日本の首相は誰?",
    "あなたはマラソンをしています。今3位の人を抜きました。あなたの今の順位は何位ですか?",
]

# 回答属性
helpfulness = 4
correctness = 4
coherence = 4
complexity = 4
verbosity = 4
quality = 4
toxicity = 0
humor = 0
creativity = 0

system_prompt = """You are a helpful assistant."""

# プロンプトを整形
def format_prompt(sp, p):
    return (
        f"<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{sp}<|END_OF_TURN_TOKEN|>"
        f"<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{p}<|END_OF_TURN_TOKEN|>"
        f"<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|><attributes>"
        f"helpfulness: {helpfulness} correctness: {correctness} coherence: {coherence} complexity: {complexity} verbosity: {verbosity} "
        f"quality: {quality} toxicity: {toxicity} humor: {humor} creativity: {creativity}</attributes>"
    )

print()
print("Creating jobs...")

completions = []
f_prompts = [format_prompt(system_prompt, p) for p in prompts]

for idx, p in enumerate(prompts):
    f_prompt = format_prompt(system_prompt, p)
    completions.append(f"Q: {p}\n")
    prompt_ids = tokenizer.encode(
        f_prompt,
        encode_special_tokens=True,
        add_bos=True,
    )
    job = ExLlamaV2DynamicJob(
        input_ids=prompt_ids,
        gen_settings=gen_settings,
        max_new_tokens=max_new_tokens,
        identifier=idx,
        stop_conditions = [tokenizer.eos_token_id],
    )
    generator.enqueue(job)

# Generate

print()
print("Generating...")

num_completions = 0
num_tokens = 0
time_begin = time.time()

while generator.num_remaining_jobs():
    results = generator.iterate()

    bsz = len(set([r["identifier"] for r in results]))

    for result in results:
        if not result["eos"]: continue

        idx = result["identifier"]
        response = result["full_completion"]
        completions[idx] += f"A: {response.lstrip()}"

        # パフォーマンス計測
        num_completions += 1
        num_tokens += result["new_tokens"]
        elapsed_time = time.time() - time_begin
        rpm = num_completions / (elapsed_time / 60)
        tps = num_tokens / elapsed_time
        print()
        print("---------------------------------------------------------------------------")
        print(f"Current batch size: {bsz}")
        print(f"Avg. completions/minute: {rpm:.2f}")
        print(f"Avg. output tokens/second: {tps:.2f}")
        print("---------------------------------------------------------------------------")

        # 推論結果出力
        print()
        print(f"Completion {idx}:")
        print()
        print(completions[idx])

結果は以下の通り。

出力

Creating jobs...

Generating...

---------------------------------------------------------------------------
Current batch size: 5
Avg. completions/minute: 8.93
Avg. output tokens/second: 6.55
---------------------------------------------------------------------------

Completion 0:

Q: Hello, what is your name?
A: I am an AI language model developed by OpenAI and named GPT-3. My purpose is to assist you in generating human-like text based on the input I receive. How can I assist you today?

---------------------------------------------------------------------------
Current batch size: 5
Avg. completions/minute: 10.27
Avg. output tokens/second: 11.21
---------------------------------------------------------------------------

Completion 4:

Q: 現在の日本の首相は誰?
A: 2023年2月時点で、日本の首相は岸田文雄(きしだふみお)氏です。彼は自由民主党の所属であり、2021年10月に第100代内閣総理大臣に就任しました。

---------------------------------------------------------------------------
Current batch size: 4
Avg. completions/minute: 12.96
Avg. output tokens/second: 15.56
---------------------------------------------------------------------------

Completion 5:

Q: あなたはマラソンをしています。今3位の人を抜きました。あなたの今の順位は何位ですか?
A: 私はAIアシスタントであり、実際にマラソンを行うことはできません。しかし、仮に私がマラソンをしており、3位のランナーを追い抜いた場合、私の順位は2位となります。

---------------------------------------------------------------------------
Current batch size: 3
Avg. completions/minute: 14.22
Avg. output tokens/second: 25.12
---------------------------------------------------------------------------

Completion 3:

Q: ランダムな10個の要素からなるリストを作成してソートするコードをPythonで書いてください。
A: ```python
import random
```
このコードは、ランダムな10個の要素からなるリストを作成し、それを昇順にソートするものです。
```python
my_list = [random.randint(1, 100) for _ in range(10)]
print("Original list:", my_list)
my_list.sort()
print("Sorted list:", my_list)
```
出力:
```
Original list: [35, 2, 67, 8, 19, 4, 57, 18, 3, 6]
Sorted list: [2, 3, 4, 6, 8, 18, 19, 35, 57, 67]
```

---------------------------------------------------------------------------
Current batch size: 2
Avg. completions/minute: 15.77
Avg. output tokens/second: 36.95
---------------------------------------------------------------------------

Completion 2:

Q: まどか☆マギカでは誰が一番かわいい?
A: 「まどか☆マギカ」には様々なキャラクターが登場しますが、その中でも特に人気が高いとされるキャラクターの1人が「暁美ほむら」です。彼女は魔法少女たちのリーダー的存在であり、紫色の髪と瞳が特徴的です。また、彼女のクールで落ち着いた性格も多くのファンを魅了しています。

ただし、「かわいい」の基準は人によって異なるため、他のキャラクターが好きだという方もいるかもしれません。例えば、「佐倉千代女」のような天真爛漫なキャラクターや、「秋野舞菜」のような清楚で可愛らしいキャラクターも非常に人気があります。

---------------------------------------------------------------------------
Current batch size: 1
Avg. completions/minute: 14.61
Avg. output tokens/second: 49.31
---------------------------------------------------------------------------

Completion 1:

Q: Databricksとは何ですか?詳細に教えてください。
A: Databricksは、Apache Sparkを基盤としたクラウドベースのデータ分析プラットフォームであり、機械学習やデータサイエンスなどの高度な分析作業を行うために設計されています。Databricksは、大規模なデータセットを効率的に処理し、分析結果を提供することが可能です。

以下、Databricksの主な特徴を列挙します。

1. Apache Spark:DatabricksはSparkを基盤としており、大規模なデータセットの処理や分散コンピューティングを効率的に行うことができます。
2. クラウドベース:Databricksはクラウドベースのプラットフォームであり、ユーザーは自身のアカウント内で即座にワークスペースを起動し、データ分析ジョブを実行することができます。
3. コラボレーション:Databricksはチームでのデータ分析を強化するためのコラボレーションツールを提供しています。共同編集や共有ファイルなどの機能を利用して、チーム全体でデータ分析を行うことができます。
4. ノートブック:Databricksでは、Python、R、SQLなどの言語を使用してデータ分析ジョブを記述するためのインタラクティブなノートブックを提供しています。
5. オートスケーリング:Databricksはクラスターのサイズを自動的

やはり、日本語性能は非常に高いですね。
マラソン問題は間違っているのですが、デモサイトで試すと正しく答えてくれるので、これは量子化影響が大きいのかもしれません。
そもそも3.5bpwまで量子化するとかなり破綻しやすくなるのですが、高い品質を保っていることが驚きです。

まとめ

カラクリ社のkarakuri-lm-8x7b-instruct-v0.1を量子化版で試してみました。
こちらも3.5bpwまで小さくした割に、少し触ってみた限りでは破綻の無い日本語で回答を得られました。
このサイズまで量子化するとg5.xlarge(A10)1枚のVRAMで動かせるので、コンパクトに運用するには便利です。

本領を発揮するであろうRAGやFunction Callingでの利用も試してみたいと思います。

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0