0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

AutoAWQ用のlangchainのカスタムChat Modelを作る

Last updated at Posted at 2023-10-16

導入

以前、AutoAWQを使ってAWQフォーマットのモデルを読み込んだりしてみました。
langchainで使えた方がいろいろ便利なのですが、AutoAWQは見たところlangchainのLLM/ChatModel対応がされていないようなので、カスタムChat Modelクラスを作って見ます。

なお、以前CTranslate2対応のカスタムChat Modelクラス作成をした記事はこちら。
基本的にはこれを踏襲します。

実行・検証はDatabricks上で行っています。

必要なモジュールのインストール

AutoAWQやlangchainなど、必要なモジュールをインストール。

%pip install -U autoawq transformers accelerate langchain

dbutils.library.restartPython()

カスタムLLMモデルを作る

最低限の実装はざっくりと以下のようなコードになります。
エラー処理等、かなり省略しているので注意。一応、ストリーミング出力対応です。

from awq import AutoAWQForCausalLM
from awq.models.base import BaseAWQForCausalLM
from transformers import (
    AutoTokenizer,
    PreTrainedTokenizer,
    PreTrainedTokenizerFast,
    TextIteratorStreamer,
)
from threading import Thread
import asyncio

from typing import (
    Any,
    List,
    Union,
    Mapping,
    Tuple,
    Optional,
)
from langchain.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain.chat_models.base import BaseChatModel
from langchain.pydantic_v1 import Field, root_validator
from langchain.schema.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    ChatMessage,
    HumanMessage,
    SystemMessage,
)
from langchain.schema import (
    ChatGeneration,
    ChatResult,
)


def load_pretrained_model(
    model_path: str,
    max_new_tokens: int = None,
    fuse_layers: bool = True,
    safetensors: bool = True,
    use_fast: bool = False,
    trust_remote_code: bool = False,
) -> Tuple[BaseAWQForCausalLM, Union[PreTrainedTokenizer, PreTrainedTokenizerFast]]:
    """事前学習済みモデルとTokenizerを返す"""
    # Load model
    model = AutoAWQForCausalLM.from_quantized(
        model_path,
        max_new_tokens=max_new_tokens,
        fuse_layers=fuse_layers,
        trust_remote_code=trust_remote_code,
        safetensors=safetensors,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        use_faset=use_fast,
        trust_remote_code=trust_remote_code,
    )

    return model, tokenizer


class ChatAutoAWQ(BaseChatModel):
    """AWQ(AutoAWQ) models.

    To use, you should have the ``autoawq`` python package installed.

    Example:
        .. code-block:: python

            generator, tokenizer = load_pretrained_model(model_path)
            hf = ChatAutoAWQ(generator=generator, tokenizer=tokenizer)
    """

    generator: BaseAWQForCausalLM
    """ AutoAWQ Model """
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
    """ Tokenizer """

    # メッセージテンプレート
    human_message_template: str = "ユーザ: {}"
    ai_message_template: str = "システム: {}"
    system_message_template: str = "{}"

    do_sample: bool = True
    max_new_tokens: int = 64
    repetition_penalty: float = 1.1
    temperature: float = 1
    top_k: int = 1
    top_p: float = 0.95

    prompt_line_separator: str = "\n"

    @property
    def _llm_type(self) -> str:
        return "AutoAWQ"

    def _format_message_as_text(self, message: BaseMessage) -> str:
        if isinstance(message, ChatMessage):
            message_text = f"{self.prompt_line_separator}{message.role.capitalize()}: {message.content}"
        elif isinstance(message, HumanMessage):
            message_text = self.human_message_template.format(message.content)
        elif isinstance(message, AIMessage):
            message_text = self.ai_message_template.format(message.content)
        elif isinstance(message, SystemMessage):
            message_text = self.system_message_template.format(message.content)
        else:
            raise ValueError(f"Got unknown type {message}")
        return message_text

    def _format_messages_as_text(self, messages: List[BaseMessage]) -> str:
        return self.prompt_line_separator.join(
            [self._format_message_as_text(message) for message in messages]
        )

    def _generate_stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> TextIteratorStreamer:

        prompt = self._format_messages_as_text(messages)

        tokens = self.tokenizer(
            prompt,
            # add_special_tokens=True,
            return_tensors="pt",
        ).input_ids.cuda()

        streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)

        # Generate output
        generation_kwargs = dict(
            inputs=tokens,
            do_sample=self.do_sample,
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            max_new_tokens=self.max_new_tokens,
            streamer=streamer,
            num_return_sequences=1,
        )
        thread = Thread(target=self.generator.generate, kwargs=generation_kwargs)
        thread.start()

        return streamer

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:

        streamer = self._generate_stream(messages, stop, run_manager, **kwargs)

        text = ""
        count = 0
        
        for new_text in streamer:
            if not new_text:
                continue

            text += new_text
            count += 1

            if run_manager:
                run_manager.on_llm_new_token(
                    new_text,
                    verbose=self.verbose,
                )

        chat_generation = ChatGeneration(message=AIMessage(content=text))
        return ChatResult(
            generations=[chat_generation],
            llm_output={"completion_tokens": count},
        )

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        streamer = self._generate_stream(messages, stop, run_manager, **kwargs)

        text = ""
        count = 0

        for new_text in streamer:

            if not new_text:
                continue

            text += new_text
            count += 1

            # await asyncio.sleep(0)
            if run_manager:
                await run_manager.on_llm_new_token(
                    new_text,
                    verbose=self.verbose,
                )

        chat_generation = ChatGeneration(message=AIMessage(content=text))
        return ChatResult(
            generations=[chat_generation],
            llm_output={"completion_tokens": count},
        )

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            "generator": self.generator,
            "tokenizer": self.tokenizer,
        }

試してみる

今回はせっかくなので、HuggingfaceのZephyr 7B AlphaのAWQモデルを使ってみます。
Zephyrについては以下の記事を参照してください。

Mistral 7Bベースでファインチューニングしたモデルです。

こちらから事前にモデルをダウンロードしておき、今回作ったクラスで読み込みます。

UC_VOLUME = "/Volumes/モデルのスナップショット保管場所"
model_dir = "/models--TheBloke--zephyr-7B-alpha-AWQ"

model_path = f"{UC_VOLUME}{model_dir}"
from chat_models.autoawq_chat import load_pretrained_model

model, tokenizer = load_pretrained_model(model_path, fuse_layers=False)
from chat_models.autoawq_chat import ChatAutoAWQ
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler


llm = ChatAutoAWQ(
    generator=model,
    tokenizer=tokenizer,
    system_message_template="<|system|>\n{}</s>",
    human_message_template="<|user|>\n{}</s>",
    ai_message_template="<|assistant|>{}",
    max_new_tokens=256,
)

result = llm.predict("東京の明日の天気は?", callbacks=[StreamingStdOutCallbackHandler()])
出力

<|assistant|>
今日は東京の天気は、晴れです。明日の天気は、晴れと少しの雨が予測されています。

動きました。実際にはストリームで出力されます。

もう少し、いろいろ試してみましょう。
Chat Modelなので、それらしく指定します。

from langchain.schema.messages import (
    AIMessage,
    HumanMessage,
) 

output = llm(
    [
        HumanMessage(content="架空の生き物をデザインし、その特徴や能力について日本語で説明してください。"),
        AIMessage(content="\n"),
    ]
)
print(output.content)
出力
「エレクトリック・フロー」という名前の生き物は、電気の力を活用している、幻的な生き物です。

この生き物は、電気の流れを感知し、その流れを利用して動きます。電気の強さに応じて、速度や方向を変えることができます。

エレクトリック・フローは、電気の強さに応じて、縁を張り、電気の強さに応じて、縁を張り、電気の強さに応じて、縁を張り、電気の強さに応じて、縁を張り、電気の強さに応じて、縁を張り、電気の強さに応じて、縁を張り、電気の強さに応じて、縁を張り、電気の強さに応じて、縁を張

なんか繰り返してますね。ペナルティの設定が必要かな。

コードも生成してみます。

output = llm(
    [
        HumanMessage(content="ランダムな10個の要素からなるリストを作成してソートするコードをPythonで書いてください。"),
        AIMessage(content="\n"),
    ]
)
print(output.content)
出力
Here's a Python code to generate a list of 10 random numbers and sort them:

```python
import random

numbers = [random.randint(1, 100) for _ in range(10)]
numbers.sort()

print(numbers)
``

Explanation:

1. We first import the `random` module to generate random numbers.

2. We create an empty list `numbers` and use a list comprehension to generate 10 random numbers between 1 and 100 (inclusive) using the `randint()` function from the `random` module.

3. We then sort the list using the `sort()` method.

4. Finally, we print the sorted list.

できました。生成されるコードの動作確認もしてみましたが、きちんと動きます。

まとめ

AWQモデルをlangchainで使えるようにするためにカスタムChat Modelを作って見ました。

しかし、Mistral 7Bのファインチューニングモデルがワンサカ出てきていて、比較的少ないパラメータで高い性能を発揮しているのが面白いです。
全体的に日本語がイマイチなところがあり、日本語が得意なファインチューニングモデルが出てくることを願ってます。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?