0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

Huggingface TransformersのAWQ統合を試す in Databricks

Posted at

今週はLLM関連いろいろありすぎてワケワカメ状態(楽しい)。
昨日出た日本語LLMなどはTheBloke兄貴のリポジトリに上がったら試そうと思います。

導入

Huggingface Transformers v4.35.0がリリースされ、今回からAWQフォーマットの量子化モデルが統合されました。

上記リンク先のベンチマーク結果に記載があるように、AWQフォーマットは以下のような特徴があります。

この結果から、AWQ量子化法は推論、テキスト生成において最も高速な量子化法であり、テキスト生成におけるピークメモリも最も少ない。しかし、AWQはバッチサイズあたりの前方遅延が最も大きい。

普段の記事でAWQフォーマットのモデルをよく使っているので、Transformers上でも使ってみます。
公式にColaboのサンプルコードがあるため、これを大いに参考にしています。

やってみる

まずは必要なモジュールをインストール。

%pip install -U transformers accelerate
%pip install autoawq

dbutils.library.restartPython()

モデルとトークナイザをロード。
モデルは公式サンプルに準じてTheBloke/Llama-2-13B-chat-AWQを使用します。
revisionでrefs/pr/4を指定しているのがポイント。 これは後から補足します。

from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import TextStreamer

model_id = "TheBloke/Llama-2-13B-chat-AWQ"

tokenizer = AutoTokenizer.from_pretrained(model_id, revision="refs/pr/4")
model = AutoModelForCausalLM.from_pretrained(model_id, revision="refs/pr/4", device_map="cuda")
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

推論を実行。TextStreamerを指定しているため、標準出力に結果がストリーム出力されます。

text = "User:\nHello can you provide me with top-3 cool places to visit in Paris?\n\nAssistant:\n"
inputs = tokenizer(text, return_tensors="pt").to(0)

out = model.generate(**inputs, max_new_tokens=300, streamer=streamer)
print(tokenizer.decode(out[0], skip_special_tokens=True))
出力
Bonjour! Paris, the City of Light, is full of cool places to visit. Here are my top three recommendations:

1. The Louvre Museum: This iconic museum is home to some of the world's most famous artworks, including the Mona Lisa. The building itself is also a work of art, with stunning architecture and gardens.
2. Montmartre: This charming neighborhood is known for its bohemian vibe, street artists, and stunning views of the city from the top of the hill. Be sure to visit the Basilique du Sacré-Cœur, a beautiful white church that sits at the summit.
3. Palace of Versailles: Just a short train ride from Paris, the Palace of Versailles is a must-see for anyone interested in history, art, and architecture. The palace's opulent interiors, gardens, and fountain shows make it a truly unforgettable experience.

I hope you enjoy your time in Paris! Do you have any other questions or preferences?

問題なく推論出力できました。

注意点

他のモデルを使おうとした際にハマったのですが、AWQフォーマットのデータを読む場合、config.json内にquantization_configの設定が含まれている必要があります。
この設定が無い場合、正しく重みをロードすることができません。

AWQフォーマットを利用する際に、間違いなくお世話になるであろうTheBloke兄貴のリポジトリには、現状ほとんどquantization_configは含まれておらず、正しく重みをロードできません。

上記例で、モデルのfrom_pretrained時にrevisionを指定していたのは、この設定が含まれるリビジョンのconfig.jsonを読んでいたためでした。

この現象の回避策として、from_pretrained時にquantization_configを含んだconfigを渡すなどの方法があります。
一例として、TheBloke/zephyr-7B-beta-AWQをロードしてみます。

from transformers import AutoModelForCausalLM, AutoTokenizer, AwqConfig, GenerationConfig
from transformers import TextStreamer, TextIteratorStreamer
from transformers import AutoConfig, AwqConfig

model_id = "TheBloke/zephyr-7B-beta-AWQ"

# AWQを読み込むための設定を手動作成
quantization_config = AwqConfig(
    bits=4,
    group_size=128,
    zero_point=True,
    version="gemm",
    backend="autoawq",
)

# 既存のconfigにquantization_configを追加
config = AutoConfig.from_pretrained(model_id)
config.quantization_config = quantization_config.to_dict()

# configを指定してロード
model = AutoModelForCausalLM.from_pretrained(model_id, config=config, device_map="cuda")
tokenizer = AutoTokenizer.from_pretrained(model_id)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

推論処理をラップして実行してみます。

def generate_stream_text(prompt: str, max_new_tokens: int = 512) -> str:

    tokens = tokenizer(prompt, return_tensors="pt").input_ids.cuda()

    generation_config = GenerationConfig(
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_k=40,
        top_p=0.95,
        temperature=0.7,
        eos_token_id=model.config.eos_token_id,
        pad_token_id=model.config.pad_token_id,
    )

    generation_output = model.generate(
        tokens,
        streamer=streamer,
        generation_config=generation_config,
    )

    return tokenizer.decode(generation_output[0], skip_special_tokens=True)
prompt = """<|system|>
You are a friendly chatbot.</s>
<|user|>
Databricksとは何ですか?</s>
<|assistant|>
"""

max_new_tokens = 128
_ = generate_stream_text(prompt, max_new_tokens)
出力
Databricksは、大データ処理に特化したクラウドベースのプラットフォームです。このプラットフォームは、アパチ(Apache) Sparkとそのエコシステムを基礎としています。Databricksは、データの取り扱い、分析、機械学習を可能にする大量のコンピューティングパワーと

問題なくロードから推論まで実行できました。

まとめ

TransformersがAWQをサポートした(バックでAutoAWQが動作するようになった)ため、これまで以上に多彩なことができるようになりました。
例えばFlash Attentionとの組み合わせなど、引き続き試してみたいと思っています。

一方、まだリリースされた機能ばかりなので、不具合があったり、使い方にクセがあったりするのではないかと思いますので、注意が必要ですね。

他にもExLlamav2のカーネルが統合されたり、マルチモーダル系の処理が統合されたりと、今回のバージョンアップ、かなり多彩なようです。
適度にキャッチアップしていきたいと思います。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?