導入
以前、AutoAWQを使ってAWQフォーマットのモデルを読み込んだりしてみました。
langchainで使えた方がいろいろ便利なのですが、AutoAWQは見たところlangchainのLLM/ChatModel対応がされていないようなので、カスタムChat Modelクラスを作って見ます。
なお、以前CTranslate2対応のカスタムChat Modelクラス作成をした記事はこちら。
基本的にはこれを踏襲します。
実行・検証はDatabricks上で行っています。
必要なモジュールのインストール
AutoAWQやlangchainなど、必要なモジュールをインストール。
%pip install -U autoawq transformers accelerate langchain
dbutils.library.restartPython()
カスタムLLMモデルを作る
最低限の実装はざっくりと以下のようなコードになります。
エラー処理等、かなり省略しているので注意。一応、ストリーミング出力対応です。
from awq import AutoAWQForCausalLM
from awq.models.base import BaseAWQForCausalLM
from transformers import (
AutoTokenizer,
PreTrainedTokenizer,
PreTrainedTokenizerFast,
TextIteratorStreamer,
)
from threading import Thread
import asyncio
from typing import (
Any,
List,
Union,
Mapping,
Tuple,
Optional,
)
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.chat_models.base import BaseChatModel
from langchain.pydantic_v1 import Field, root_validator
from langchain.schema.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
ChatMessage,
HumanMessage,
SystemMessage,
)
from langchain.schema import (
ChatGeneration,
ChatResult,
)
def load_pretrained_model(
model_path: str,
max_new_tokens: int = None,
fuse_layers: bool = True,
safetensors: bool = True,
use_fast: bool = False,
trust_remote_code: bool = False,
) -> Tuple[BaseAWQForCausalLM, Union[PreTrainedTokenizer, PreTrainedTokenizerFast]]:
"""事前学習済みモデルとTokenizerを返す"""
# Load model
model = AutoAWQForCausalLM.from_quantized(
model_path,
max_new_tokens=max_new_tokens,
fuse_layers=fuse_layers,
trust_remote_code=trust_remote_code,
safetensors=safetensors,
)
tokenizer = AutoTokenizer.from_pretrained(
model_path,
use_faset=use_fast,
trust_remote_code=trust_remote_code,
)
return model, tokenizer
class ChatAutoAWQ(BaseChatModel):
"""AWQ(AutoAWQ) models.
To use, you should have the ``autoawq`` python package installed.
Example:
.. code-block:: python
generator, tokenizer = load_pretrained_model(model_path)
hf = ChatAutoAWQ(generator=generator, tokenizer=tokenizer)
"""
generator: BaseAWQForCausalLM
""" AutoAWQ Model """
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
""" Tokenizer """
# メッセージテンプレート
human_message_template: str = "ユーザ: {}"
ai_message_template: str = "システム: {}"
system_message_template: str = "{}"
do_sample: bool = True
max_new_tokens: int = 64
repetition_penalty: float = 1.1
temperature: float = 1
top_k: int = 1
top_p: float = 0.95
prompt_line_separator: str = "\n"
@property
def _llm_type(self) -> str:
return "AutoAWQ"
def _format_message_as_text(self, message: BaseMessage) -> str:
if isinstance(message, ChatMessage):
message_text = f"{self.prompt_line_separator}{message.role.capitalize()}: {message.content}"
elif isinstance(message, HumanMessage):
message_text = self.human_message_template.format(message.content)
elif isinstance(message, AIMessage):
message_text = self.ai_message_template.format(message.content)
elif isinstance(message, SystemMessage):
message_text = self.system_message_template.format(message.content)
else:
raise ValueError(f"Got unknown type {message}")
return message_text
def _format_messages_as_text(self, messages: List[BaseMessage]) -> str:
return self.prompt_line_separator.join(
[self._format_message_as_text(message) for message in messages]
)
def _generate_stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> TextIteratorStreamer:
prompt = self._format_messages_as_text(messages)
tokens = self.tokenizer(
prompt,
# add_special_tokens=True,
return_tensors="pt",
).input_ids.cuda()
streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)
# Generate output
generation_kwargs = dict(
inputs=tokens,
do_sample=self.do_sample,
temperature=self.temperature,
top_p=self.top_p,
top_k=self.top_k,
max_new_tokens=self.max_new_tokens,
streamer=streamer,
num_return_sequences=1,
)
thread = Thread(target=self.generator.generate, kwargs=generation_kwargs)
thread.start()
return streamer
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
streamer = self._generate_stream(messages, stop, run_manager, **kwargs)
text = ""
count = 0
for new_text in streamer:
if not new_text:
continue
text += new_text
count += 1
if run_manager:
run_manager.on_llm_new_token(
new_text,
verbose=self.verbose,
)
chat_generation = ChatGeneration(message=AIMessage(content=text))
return ChatResult(
generations=[chat_generation],
llm_output={"completion_tokens": count},
)
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
streamer = self._generate_stream(messages, stop, run_manager, **kwargs)
text = ""
count = 0
for new_text in streamer:
if not new_text:
continue
text += new_text
count += 1
# await asyncio.sleep(0)
if run_manager:
await run_manager.on_llm_new_token(
new_text,
verbose=self.verbose,
)
chat_generation = ChatGeneration(message=AIMessage(content=text))
return ChatResult(
generations=[chat_generation],
llm_output={"completion_tokens": count},
)
@property
def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters."""
return {
"generator": self.generator,
"tokenizer": self.tokenizer,
}
試してみる
今回はせっかくなので、HuggingfaceのZephyr 7B Alpha
のAWQモデルを使ってみます。
Zephyrについては以下の記事を参照してください。
Mistral 7Bベースでファインチューニングしたモデルです。
こちらから事前にモデルをダウンロードしておき、今回作ったクラスで読み込みます。
UC_VOLUME = "/Volumes/モデルのスナップショット保管場所"
model_dir = "/models--TheBloke--zephyr-7B-alpha-AWQ"
model_path = f"{UC_VOLUME}{model_dir}"
from chat_models.autoawq_chat import load_pretrained_model
model, tokenizer = load_pretrained_model(model_path, fuse_layers=False)
from chat_models.autoawq_chat import ChatAutoAWQ
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
llm = ChatAutoAWQ(
generator=model,
tokenizer=tokenizer,
system_message_template="<|system|>\n{}</s>",
human_message_template="<|user|>\n{}</s>",
ai_message_template="<|assistant|>{}",
max_new_tokens=256,
)
result = llm.predict("東京の明日の天気は?", callbacks=[StreamingStdOutCallbackHandler()])
<|assistant|>
今日は東京の天気は、晴れです。明日の天気は、晴れと少しの雨が予測されています。
動きました。実際にはストリームで出力されます。
もう少し、いろいろ試してみましょう。
Chat Modelなので、それらしく指定します。
from langchain.schema.messages import (
AIMessage,
HumanMessage,
)
output = llm(
[
HumanMessage(content="架空の生き物をデザインし、その特徴や能力について日本語で説明してください。"),
AIMessage(content="\n"),
]
)
print(output.content)
「エレクトリック・フロー」という名前の生き物は、電気の力を活用している、幻的な生き物です。
この生き物は、電気の流れを感知し、その流れを利用して動きます。電気の強さに応じて、速度や方向を変えることができます。
エレクトリック・フローは、電気の強さに応じて、縁を張り、電気の強さに応じて、縁を張り、電気の強さに応じて、縁を張り、電気の強さに応じて、縁を張り、電気の強さに応じて、縁を張り、電気の強さに応じて、縁を張り、電気の強さに応じて、縁を張り、電気の強さに応じて、縁を張
なんか繰り返してますね。ペナルティの設定が必要かな。
コードも生成してみます。
output = llm(
[
HumanMessage(content="ランダムな10個の要素からなるリストを作成してソートするコードをPythonで書いてください。"),
AIMessage(content="\n"),
]
)
print(output.content)
Here's a Python code to generate a list of 10 random numbers and sort them:
```python
import random
numbers = [random.randint(1, 100) for _ in range(10)]
numbers.sort()
print(numbers)
``
Explanation:
1. We first import the `random` module to generate random numbers.
2. We create an empty list `numbers` and use a list comprehension to generate 10 random numbers between 1 and 100 (inclusive) using the `randint()` function from the `random` module.
3. We then sort the list using the `sort()` method.
4. Finally, we print the sorted list.
できました。生成されるコードの動作確認もしてみましたが、きちんと動きます。
まとめ
AWQモデルをlangchainで使えるようにするためにカスタムChat Modelを作って見ました。
しかし、Mistral 7Bのファインチューニングモデルがワンサカ出てきていて、比較的少ないパラメータで高い性能を発揮しているのが面白いです。
全体的に日本語がイマイチなところがあり、日本語が得意なファインチューニングモデルが出てくることを願ってます。