1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

ExLlamaV2用のlangchain カスタムChat Modelを作る

Last updated at Posted at 2023-11-23

導入

最近AWQフォーマットのローカルLLMを試す機会が多いのですが、GPTQなどの以前からある量子化フォーマットを利用したいときがあります。

ただ、私のやり方が悪いのか、Databricks上での使用においてAutoGPTQ(Transformers統合含む)をうまくGPUで動作させることができていません。

一方、以前取り上げたExLlamaはGPTQフォーマットのモデルを利用でき、またDatabricks上で難なく動かすことができます。ExLlamaV2もバージョンが上がってきており、以前より安定的に動くようになってきたようなので、今後楽に使えるようにするためにlangchainのカスタムChatモデルを作ります。

基本は、以前作成したAutoAWQやTransformersのChat Modelを作ったのと同様のやり方となります。

実装・検証はDatabricks上で行っています。DBRは14.1 ML、GPUクラスタを使いました。
ExLlamaV2は執筆時点でVersion 0.0.9です。

Step1. langchainのカスタムChat Modelを作る

こちらのExLlama V2 Exampleを基に、最低限のChatModel実装をしました。
exllamav2_chat.pyというファイルで保管します。

exllamav2_chat.py
import asyncio

import itertools

from typing import (
    Any,
    List,
    Union,
    Mapping,
    Tuple,
    Optional,
    Iterator,
    AsyncIterator,
)
from langchain.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain.chat_models.base import BaseChatModel
from langchain.pydantic_v1 import Field, root_validator
from langchain.schema.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    ChatMessage,
    HumanMessage,
    SystemMessage,
)
from langchain.schema import (
    ChatGeneration,
    ChatResult,
)
from langchain.schema.output import ChatGenerationChunk

from exllamav2 import (
    ExLlamaV2,
    ExLlamaV2Config,
    ExLlamaV2Cache,
    ExLlamaV2Cache_8bit,
    ExLlamaV2Tokenizer,
)

from exllamav2.generator import (
    ExLlamaV2BaseGenerator,
    ExLlamaV2StreamingGenerator,
    ExLlamaV2Sampler,
)


class ChatExllamaV2Model(BaseChatModel):
    """exllamav2 models.

    To use, you should have the ``exllamav2`` python package installed.

    Example:
        .. code-block:: python

    """

    exllama_config: ExLlamaV2Config
    """ ExLlamaV2 config """
    exllama_model: ExLlamaV2
    """ ExLlamaV2 pretrained Model """
    exllama_tokenizer: ExLlamaV2Tokenizer
    """ ExLlamaV2 Tokenizer """
    exllama_cache: Union[ExLlamaV2Cache, ExLlamaV2Cache_8bit]
    """ ExLLamaV2 Cache """

    # メッセージテンプレート
    human_message_template: str = "USER: {}\n"
    ai_message_template: str = "ASSISTANT: {}"
    system_message_template: str = "{}"

    do_sample: bool = True
    max_new_tokens: int = 64
    repetition_penalty: float = 1.1
    temperature: float = 1
    top_k: int = 50
    top_p: float = 0.8

    prompt_line_separator: str = "\n"

    @classmethod
    def from_model_dir(
        cls,
        model_dir: str,
        cache_8bit:bool=False,
        cache_max_seq_len: int = -1,
        low_mem = False,
        tokenizer_force_json = False,
        **kwargs: Any,
    ) -> "ChatExllamaV2Model":
        """Construct the exllamav2 model and tokenzier pipeline object from model_id"""

        # Initialize config
        config = ExLlamaV2Config()
        config.model_dir = model_dir
        if low_mem:
            config.set_low_mem()
        config.prepare()

        # Initialize model
        model = ExLlamaV2(config)

        # Initialize tokenizer
        tokenizer = ExLlamaV2Tokenizer(config, force_json=tokenizer_force_json)

        # cache
        cache = None
        if cache_8bit:
            cache = ExLlamaV2Cache_8bit(model, lazy=not model.loaded, max_seq_len=cache_max_seq_len)
        else:
            cache = ExLlamaV2Cache(model, lazy=not model.loaded, max_seq_len=cache_max_seq_len)

        # load model
        model.load_autosplit(cache)

        return cls(
            exllama_config=config,
            exllama_model=model,
            exllama_tokenizer=tokenizer,
            exllama_cache=cache,
            **kwargs,
        )

    @property
    def _llm_type(self) -> str:
        return "ChatExllamaV2Model"

    def _format_message_as_text(self, message: BaseMessage) -> str:
        if isinstance(message, ChatMessage):
            message_text = f"{self.prompt_line_separator}{message.role.capitalize()}: {message.content}"
        elif isinstance(message, HumanMessage):
            message_text = self.human_message_template.format(message.content)
        elif isinstance(message, AIMessage):
            message_text = self.ai_message_template.format(message.content)
        elif isinstance(message, SystemMessage):
            message_text = self.system_message_template.format(message.content)
        else:
            raise ValueError(f"Got unknown type {message}")
        return message_text

    def _format_messages_as_text(self, messages: List[BaseMessage]) -> str:
        return self.prompt_line_separator.join(
            [self._format_message_as_text(message) for message in messages]
        )

    def _generate_streamer(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ExLlamaV2StreamingGenerator:

        prompt = self._format_messages_as_text(messages)
        if self.verbose:
            print(prompt)

        _stop = stop or []
        # print("STOP:", _stop)

        generator = ExLlamaV2StreamingGenerator(
            self.exllama_model, self.exllama_cache, self.exllama_tokenizer
        )

        # Settings
        settings = ExLlamaV2Sampler.Settings()
        settings.temperature = self.temperature
        settings.top_k = self.top_k
        settings.top_p = self.top_p
        settings.token_repetition_penalty = self.repetition_penalty

        # Prompt
        # add_bos/add_eosの指定はプロンプトの種類に応じて考慮が必要かも。
        input_ids = self.exllama_tokenizer.encode(
            prompt, add_bos=True, add_eos=False, encode_special_tokens=True
        )

        prompt_tokens = input_ids.shape[-1]
        
        generator.set_stop_conditions(_stop + [self.exllama_tokenizer.eos_token_id])
        generator.begin_stream(input_ids, settings)

        return generator

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:

        streamer = self._generate_streamer(messages, stop, run_manager, **kwargs)

        texts = []
        generated_tokens = 0

        while True:
            chunk, eos, _ = streamer.stream()
            generated_tokens += 1
            texts.append(chunk)
            if run_manager:
                run_manager.on_llm_new_token(
                    chunk,
                    verbose=self.verbose,
                )

            if eos or generated_tokens == self.max_new_tokens:
                break

        chat_generation = ChatGeneration(message=AIMessage(content="".join(texts)))
        return ChatResult(
            generations=[chat_generation],
            llm_output={"completion_tokens": generated_tokens},
        )

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        streamer = self._generate_streamer(messages, stop, run_manager, **kwargs)

        texts = []
        generated_tokens = 0

        while True:
            chunk, eos, _ = streamer.stream()
            generated_tokens += 1
            texts.append(chunk)
            if run_manager:
                await run_manager.on_llm_new_token(
                    chunk,
                    verbose=self.verbose,
                )

            await asyncio.sleep(0)
            if eos or generated_tokens == self.max_new_tokens:
                break

        chat_generation = ChatGeneration(message=AIMessage(content="".join(texts)))
        return ChatResult(
            generations=[chat_generation],
            llm_output={"completion_tokens": generated_tokens},
        )

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:

        streamer = self._generate_streamer(messages, stop, run_manager, **kwargs)
        generated_tokens = 0

        while True:
            chunk, eos, _ = streamer.stream()
            generated_tokens += 1
            yield ChatGenerationChunk(message=AIMessageChunk(content=chunk))
            if run_manager:
                run_manager.on_llm_new_token(
                    chunk,
                    verbose=self.verbose,
                )
            if eos or generated_tokens == self.max_new_tokens:
                break

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Union[List[str], None] = None,
        run_manager: Union[AsyncCallbackManagerForLLMRun, None] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:

        streamer = self._generate_streamer(messages, stop, run_manager, **kwargs)
        generated_tokens = 0

        while True:
            chunk, eos, _ = streamer.stream()
            generated_tokens += 1
            yield ChatGenerationChunk(message=AIMessageChunk(content=chunk))
            if run_manager:
                await run_manager.on_llm_new_token(
                    chunk,
                    verbose=self.verbose,
                )
            await asyncio.sleep(0)
            if eos or generated_tokens == self.max_new_tokens:
                break

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            "config": self.exllama_config,
            "model": self.exllama_model,
            "tokenizer": self.exllama_tokenizer,
        }

Step2. 簡単に試す

ノートブックを作成し、簡単に動作を確認してみます。

まずは必要なモジュールをインストール。

%pip install -U -qq transformers accelerate exllamav2 langchain

dbutils.library.restartPython()

Chainのプロンプトを作成。

from langchain.prompts import ChatPromptTemplate
from langchain.prompts.chat import (
    AIMessagePromptTemplate,
    HumanMessagePromptTemplate,
)

template = """{question}"""
prompt = ChatPromptTemplate.from_messages(
    [
        HumanMessagePromptTemplate.from_template(template),
        AIMessagePromptTemplate.from_template(""),
    ]
)

カスタムチャットモデルを作成。
モデルは以下を使わせていただきました。OpenChat 3.5 7Bはなんだかんだ使いやすいので。
GPTQモデルは複数バージョンありますが、今回はmainブランチからダウンロードしてきたものを使います。

from transformers import AutoModelForCausalLM, AutoTokenizer
from exllamav2_chat import ChatExllamaV2Model

model_path = "/Volumes/training/llm/model_snapshots/models--TheBloke--openchat_3.5-GPTQ"

chat_model = ChatExllamaV2Model.from_model_dir(
    model_path,
    human_message_template="GPT4 User: {}<|end_of_turn|>",
    ai_message_template="GPT4 Assistant: {}",
    temperature=0.1,
    max_new_tokens=1024,
)

Chainを作成して実行。

from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough, RunnableLambda

chat_model.max_new_tokens = 512

chain = (
    {"question": RunnablePassthrough()} 
    | prompt 
    | chat_model 
    | StrOutputParser()
)

async for s in chain.astream("Databrickとは何ですか?120文字以内で回答してください。"):
    print(s, end="", flush=True)
出力
データブリック(Databricks)は、アメリカ合衆国の企業で、オープンソースのデータ処理フレームワーク、Apache Sparkをベースにした分析プラットフォームを提供するコンパニーです。Cloud-basedのデータエンジニアリングおよびマシン学習プラットフォームを提供し、大規模なデータセットを効率的に処理できるようにしています。AWS, Azure, そしてGoogle Cloud Platform上で利用可能であり、データサイエンティストやデータエンジニアが高速なデータ分析を行うことができます。

ちゃんと動きました!

まとめ

ExLlama V2対応のlangchainカスタムチャットモデルを作成しました。

ちなみにExLlamaV2の推論速度は体感かなり速いです。GPTQでの推論ならこれ一択でいいのではないでしょうか。
今回のモデルだとVRAM使用量は推論実行後で5.4GBほど。
以前の記事でも触れましたが、EXL2という独自の量子化フォーマットもサポートしており、以下の方が積極的に変換モデルを公開されています。

近々、いろんなフレームワークでの推論速度比較をやってみたいと思います。

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?