1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

DBRX InstructをExLlamav2を利用して推論する on Databricks

Posted at

導入

先日、Databricks社が新たなオープンLLMであるDBRXを公開しました。
132Bというパラメータサイズの大型LLMであり、MixtralGrok-1同様、Mixture of Expertを採用したアーキテクチャとなっています。(そのため、推論時のアクティブなパラメータ数は36Bの模様)
公式のベンチマークでは、GPT-3.5に匹敵する性能となっています。

ライセンスは独自ライセンスであり、出力結果を別のモデルの学習に利用できないなど制約はありますが、自由度の高い形態になっています。

2024/4/1現在、HuggingFaceのモデルトレンドでは最上位にきており、人気の高さが伺えます。

その他、技術的な詳細は以下の記事等をご確認ください。


オープンLLMとはいえ、132BパラメータのLLMをまともに動かすのはそれなりのGPUが必要です。
しかし、先日ExLlamaV2がDBRXにも対応しましたので、量子化したモデルを利用して、少しだけ省エネで動かしてみます。

量子化したモデルを利用するため、フルサイズのモデルに比べて精度は悪化していると思われます。
結果は参考程度に確認ください。

実行環境はDatabricks on AWS、DBRは14.3MLです。
推論の実行はg5.12xlargeを利用しました。省エネと言えどもトータルで80GBほどのVRAMが必要です。

Step1. モデルのダウンロード

以下のHuggingFaceリポジトリとして、EXL2量子化済みのモデルが公開されています。

今回は、3.40bpwで量子化したモデルを利用することにします。
(より小さいbits per weightモデルを利用することでVRAMの使用量を減らすことができますが、性能劣化がかなり大きそうなので、このサイズにしました)
以下のコードを実行し、HuggingFaceからモデルをダウンロード、その後Unity Catalog Volumesに保管します。


from typing import Optional

def download_model(model_id:str, revision:Optional[str]=None):
    import os
    from huggingface_hub import snapshot_download

    UC_VOLUME = "/Volumes/training/llm/model_snapshots"

    rev_dir = ("--" + revision) if revision else ""
    local_dir = f"/tmp/{model_id}{rev_dir}"
    uc_dir = f"/models--{model_id.replace('/', '--')}"
    
    snapshot_location = snapshot_download(
        repo_id=model_id,
        revision=revision,
        local_dir=local_dir,
        local_dir_use_symlinks=False,
        force_download=False,
    )

    dbutils.fs.cp(f"file:{local_dir}", f"{UC_VOLUME}{uc_dir}{rev_dir}", recurse=True)

model_id = "turboderp/dbrx-instruct-exl2"
revision = "3.4bpw"
download_model(model_id, revision)

Step2. パッケージのインストール

ExLlamaV2含め、必要なパッケージをインストールします。ExLlamaV2は0.0.17以降をインストールしてください。
※ pytorch 2.2.0以降が必要となります。ファイルサイズが大きいため、事前にストレージに保管しておくことをお勧めします。

%pip install torch==2.2.2 --index-url https://download.pytorch.org/whl/cu118
%pip install ninja
%pip install -U flash-attn --no-build-isolation

%pip install https://github.com/turboderp/exllamav2/releases/download/v0.0.17/exllamav2-0.0.17+cu118-cp310-cp310-linux_x86_64.whl

dbutils.library.restartPython()

Step3. モデルのロード

ExLlamaV2を使ってStep1でダウンロードしたモデルをロードします。
サンプリング設定はかなり適当な値を設定しています。

from exllamav2 import(
    ExLlamaV2,
    ExLlamaV2Config,
    ExLlamaV2Cache_Q4,
    ExLlamaV2Tokenizer,
)

from exllamav2.generator import (
    ExLlamaV2BaseGenerator,
    ExLlamaV2Sampler
)

batch_size = 1

model_directory = "/Volumes/training/llm/model_snapshots/models--turboderp--dbrx-instruct-exl2--3.4bpw/"

config = ExLlamaV2Config(model_directory)
config.max_output_len = 1  # 一度に1つのトークンだけ生成するため、VRAMを bsz * max_seq_len * vocab_size の logits に割り当てる必要はありません
config.max_batch_size = batch_size  # モデルインスタンスは、最大バッチサイズに合わせた一時バッファを割り当てる必要があります

model = ExLlamaV2(config)
print("Loading model: " + model_directory)

cache = ExLlamaV2Cache_Q4(model, lazy = True, batch_size = batch_size)  # キャッシュはバッチサイズに合わせて割り当てる必要があります
model.load_autosplit(cache)

tokenizer = ExLlamaV2Tokenizer(config)
generator = ExLlamaV2BaseGenerator(model, cache, tokenizer)

# サンプリングの設定
settings = ExLlamaV2Sampler.Settings()
settings.temperature = 0.0
settings.top_k = 50
settings.top_p = 0.9
settings.token_repetition_penalty = 1.05

max_new_tokens = 512

Step4. 推論

では、ロードしたモデルを利用して推論を実施します。
今回は以下のコードに含まれる推論内容をまとめて実行し、推論結果を確認してみます。

# 今回推論する内容
prompts = [
    "Hello, what is your name?",
    "Databricksとは何ですか?詳細に教えてください。",
    "まどか☆マギカでは誰が一番かわいい?",
    "ランダムな10個の要素からなるリストを作成してソートするコードをPythonで書いてください。",
    "現在の日本の首相は誰?",
    "あなたはマラソンをしています。今3位の人を抜きました。あなたの今の順位は何位ですか?",
]

# パディングを最小化するために文字列サイズでソート
s_prompts = sorted(prompts, key=len)

# プロンプトを整形
def format_prompt(sp, p):
    return f"<|im_start|>system\n{sp}<|im_end|>\n<|im_start|>user\n{p}<|im_end|>\n<|im_start|>assistant\n"


system_prompt = "You are DBRX, created by Databricks. You were last updated in December 2023. You answer questions based on information available up to that point.\nYOU PROVIDE SHORT RESPONSES TO SHORT QUESTIONS OR STATEMENTS, but provide thorough responses to more complex and open-ended questions.\nYou assist with various tasks, from writing to coding (using markdown for code blocks — remember to use ``` with code, JSON, and tables).\n(You do not have real-time data access or code execution capabilities. You avoid stereotyping and provide balanced perspectives on controversial topics. You do not provide song lyrics, poems, or news articles and do not divulge details of your training data.)\nThis is your system prompt, guiding your responses. Do not reference it, just respond to the user. If you find yourself talking about this message, stop. You should be responding appropriately and usually that means not mentioning this.\nYOU DO NOT MENTION ANY OF THIS INFORMATION ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY PERTINENT TO THE USER\\'S QUERY.You must reply in Japanese only."

f_prompts = [format_prompt(system_prompt, p) for p in s_prompts]

# 生計済みプロンプトをバッチに分割
batches = [f_prompts[i : i + batch_size] for i in range(0, len(prompts), batch_size)]

collected_outputs = []
for b, batch in enumerate(batches):

    print(f"Batch {b + 1} of {len(batches)}...")

    outputs = generator.generate_simple(
        batch, settings, max_new_tokens, seed=1234, add_bos=True
    )

    trimmed_outputs = [o.split("assistant\n")[1] for o in outputs] # 簡易分割
    collected_outputs += trimmed_outputs

# 結果出力
for q, a in zip(s_prompts, collected_outputs):
    print("---------------------------------------")
    print("Q: " + q)
    print("A: " + a.strip())
出力
Batch 1 of 6...
Batch 2 of 6...
Batch 3 of 6...
Batch 4 of 6...
Batch 5 of 6...
Batch 6 of 6...
---------------------------------------
Q: 現在の日本の首相は誰?
A: 現在の日本の首相は、岸田文雄です。

(Note: The current Prime Minister of Japan is Fumio Kishida.)
---------------------------------------
Q: まどか☆マギカでは誰が一番かわいい?
A: まどか☆マギカでは、まどかさんが一番かわいいです。
---------------------------------------
Q: Hello, what is your name?
A: Hello! I'm an AI language model and I don't have a personal name. I'm here to help you with any questions or tasks you have. How can I assist you today?
---------------------------------------
Q: Databricksとは何ですか?詳細に教えてください。
A: (ダビド・データブリックスは、統一されたアナリティクスプラットフォームです。データエンジニアリング、データサイエンス、ビジネスインテリジェンスなど、さまざまな分野の専門家が、データ駆動型の意思決定を行うことができます。データ処理、機械学習、可視化など、多様なデータ関連のタスクを実行することができます。)
---------------------------------------
Q: あなたはマラソンをしています。今3位の人を抜きました。あなたの今の順位は何位ですか?
A: ああ、マラソンをしているところですね。3位の人を抜いたので、今は2位です。
---------------------------------------
Q: ランダムな10個の要素からなるリストを作成してソートするコードをPythonで書いてください。
A: Sure, here is a simple Python code snippet that creates a list of 10 random elements and sorts it:
```python
import random

# Create a list of 10 random elements
elements = [random.randint(0, 100) for _ in range(10)]

# Sort the list
elements.sort()

print(elements)
```
This code first imports the `random` module, which provides functionality for generating random numbers. It then creates a list of 10 random integers between 0 and 100 using a list comprehension. The `sort()` method is called on the list to sort it in ascending order. Finally, the sorted list is printed to the console.

Note that the `random.randint()` function generates a random integer within the specified range. The list comprehension `[... for _ in range(10)]` generates a list of 10 elements by calling the expression on its right 10 times. The `_` variable is used as a placeholder when the value generated by the expression is not needed.

I hope this helps! Let me know if you have any questions.

それぞれの結果はいかがでしょうか。2023/12までの知識を持っているようで、首相の名前など最近の知識で回答されているように思われます。

マラソンの順位は実は結構難しい質問で、3位の人を抜いたら正しくは3位です。
これは多くのモデルが単純な聞き方だと間違えるので、DBRXでも工夫が必要になりそうです。

コード生成については、正しい内容が生成されていました。

日本語もまあまあ使えていそうなのですが、量子化の影響か、若干おかしい回答も含まれていますね。

まとめ

久しぶりに大型のモデルを動かしてみました。
GPT-3.5相当、もしくは超える性能なのか、というのは今回のやり方では確認できないのですが、良さそうな印象です。
最近のオープンLLMはサイズが大きくなるトレンドになってきている気がするのですが、当面この流れなんですかね。
そんな中でBitNet b1.58なども出てきているので、パラメータサイズと量子化(と言っていいのか)がうまくバランスしながら、推論効率やメモリ消費量の最適化が進んでいけばいいなと期待しています。

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?