1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

langchainのカスタムChat Modelを作る

Posted at

先に結論

  • langchain.chat_models.base.BaseChatModelを継承したクラスを作成して、最低限_generateを実装すれば動く。
  • もう少しリッチにする場合は、_agenerate_stream_astreamも実装するとよい。

導入

チャットボットのようなアプリをlangchainで作る場合、LLMsよりもChat Modelsのほうが何かと使い勝手がいい(気がする)。
とはいえ、Huggingfaceのモデルなど、ローカルLLMをChat Modelとして動かす仕組みは今のところ無いので、サクッとカスタムモデルを用意しようと考えたのですが、カスタムLLMのドキュメントはあっても、カスタムChatModelのドキュメントがない。
※ だいぶ前のIssueに質問が上がっていますが、望むような回答が付いていませんでした。

というわけで、github内のコードを見ながらやりかたを模索・実装してみます。

注意
完全自己流なので、参考にする場合は自己責任でお願いします。

方法の模索

Chat Modelのベースクラスが以下のソース内で定義されているので、このあたり読むと何となくわかります。

ローカルLLM対応にあたっては、以下が参考になりました。

他、同じフォルダ内のfake.pyhuman.pyなどのソースコードも有用でした。

結論としては冒頭の通り、

  • langchain.chat_models.base.BaseChatModelを継承したクラスを作成して、最低限_generateを実装すれば動く模様。(langchain.chat_models.base.SimpleChatModelでもOK)
  • _agenerate_stream_astreamも標準実装が無いため、実装するといいと思います。

実装

というわけで試験実装やってみました。
題材として、こちらで作ったCTranslate2用のカスタムLLMをChat Modelとして作り直してみます。

コードは以下。長いので折り畳み。

試験実装
from ctranslate2 import Generator, GenerationStepResult
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

from typing import (
    Any,
    List,
    Union,
    Dict,
    Mapping,
    Optional,
    Iterable,
    Iterator,
    AsyncIterator,
    Callable,
)
from langchain.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain.chat_models.base import BaseChatModel
from langchain.pydantic_v1 import Field, root_validator
from langchain.schema.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    ChatMessage,
    HumanMessage,
    SystemMessage,
)
from langchain.schema import (
    ChatGeneration,
    ChatResult,
)
from langchain.schema.output import ChatGenerationChunk


def _default_format_message_as_text(message: BaseMessage) -> str:
    if isinstance(message, ChatMessage):
        message_text = f"\n\n{message.role.capitalize()}: {message.content}"
    elif isinstance(message, HumanMessage):
        message_text = f"ユーザ: {message.content}"
    elif isinstance(message, AIMessage):
        message_text = f"システム: {message.content}"
    elif isinstance(message, SystemMessage):
        message_text = f"{message.content}"
    else:
        raise ValueError(f"Got unknown type {message}")
    return message_text


class ChatCTranslate2(BaseChatModel):

    generator: Generator
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
    format_message_as_text_func: Callable = Field(
        default_factory=lambda: _default_format_message_as_text
    )

    max_length: int = 256
    min_length: int = 0
    repetition_penalty: float = 1.1
    temperature: float = 1
    topk: int = 1
    topp: float = 1
    no_repeat_ngram_size: int = 0
    disable_unk: bool = False

    static_prompt: Optional[List[str]] = None
    cache_static_prompt: bool = True

    prompt_line_separator: str = "\n"

    @property
    def _llm_type(self) -> str:
        return "CTranslate2"

    def _format_messages_as_text(self, messages: List[BaseMessage]) -> str:
        return self.prompt_line_separator.join(
            [self.format_message_as_text_func(message) for message in messages]
        )

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:

        prompt = self._format_messages_as_text(messages)
        step_results = self._generate_tokens(prompt)

        output_ids = []
        token_buffer = []

        for step_result in step_results:

            output_ids.append(step_result.token_id)

            if run_manager:
                if chunk := self._decode_with_buffer(step_result, token_buffer):

                    run_manager.on_llm_new_token(
                        chunk.text,
                        verbose=self.verbose,
                        logprobs=step_result.log_prob if step_result.log_prob else None,
                    )

        text = self.tokenizer.decode(
            output_ids,
            skip_special_tokens=True,
        )

        chat_generation = ChatGeneration(message=AIMessage(content=text))
        return ChatResult(
            generations=[chat_generation],
            llm_output={"completion_tokens": len(output_ids)},
        )

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:

        prompt = self._format_messages_as_text(messages)
        step_results = self._generate_tokens(prompt)

        output_ids = []
        token_buffer = []

        for step_result in step_results:

            output_ids.append(step_result.token_id)

            if run_manager:
                if chunk := self._decode_with_buffer(step_result, token_buffer):

                    await run_manager.on_llm_new_token(
                        chunk.text,
                        verbose=self.verbose,
                        logprobs=step_result.log_prob if step_result.log_prob else None,
                    )

        text = self.tokenizer.decode(
            output_ids,
            skip_special_tokens=True,
        )

        chat_generation = ChatGeneration(message=AIMessage(content=text))
        return ChatResult(
            generations=[chat_generation],
            llm_output={"completion_tokens": len(output_ids)},
        )

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:

        prompt = self._format_messages_as_text(messages)
        step_results = self._generate_tokens(prompt)

        token_buffer = []
        for step_result in step_results:

            if chunk := self._decode_with_buffer(step_result, token_buffer):
                yield chunk

                if run_manager:
                    run_manager.on_llm_new_token(
                        chunk.text,
                        verbose=self.verbose,
                        logprobs=step_result.log_prob if step_result.log_prob else None,
                    )

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Union[List[str], None] = None,
        run_manager: Union[AsyncCallbackManagerForLLMRun, None] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        prompt = self._format_messages_as_text(messages)
        step_results = self._generate_tokens(prompt)

        token_buffer = []
        for step_result in step_results:

            if chunk := self._decode_with_buffer(step_result, token_buffer):
                yield chunk

                if run_manager:
                    await run_manager.on_llm_new_token(
                        chunk.text,
                        verbose=self.verbose,
                        logprobs=step_result.log_prob if step_result.log_prob else None,
                    )

    def _generate_tokens(
        self,
        prompt: str,
    ) -> Iterable[GenerationStepResult]:
        """
        Generates a tuple of tokens from the input text.
        Removes special characters and punctuation, and splits the text on whitespace.

        Args:
        - text (str): The text to generate tokens from.

        Returns:
        - Tuple[str]: A tuple of tokens generated from the input text.
        """

        # 推論の実行
        tokens = self.tokenizer.convert_ids_to_tokens(
            self.tokenizer.encode(prompt, add_special_tokens=False)
        )

        step_results = self.generator.generate_tokens(
            tokens,
            max_length=self.max_length,
            min_length=self.min_length,
            sampling_topk=self.topk,
            sampling_topp=self.topp,
            sampling_temperature=self.temperature,
            repetition_penalty=self.repetition_penalty,
            return_log_prob=True,
            no_repeat_ngram_size=self.no_repeat_ngram_size,
            disable_unk=self.disable_unk,
            static_prompt=self.static_prompt,
            cache_static_prompt=self.cache_static_prompt,
        )

        return step_results

    def _decode_with_buffer(
        self, step_result: GenerationStepResult, token_buffer: list
    ) -> Union[ChatGenerationChunk, None]:
        """
        Decodes the token buffer with the provided step_result,
        returning a ChatGenerationChunk if a valid word is decoded.

        Args:
        - step_result (GenerationStepResult): The step result to decode the token buffer with.
        - token_buffer (list): The token buffer to decode.

        Returns:
        - Union[ChatGenerationChunk, None]: Returns a ChatGenerationChunk containing the content of the decoded buffer,
            along with the generation_info of step_result if a valid word is decoded, otherwise None.
        """

        token_buffer.append(step_result.token_id)
        word = self.tokenizer.decode(
            token_buffer,
            skip_special_tokens=True,
        )

        # 全て変換不能文字の場合、終了
        if all(c == "" for c in word):
            return None

        # step_resultのtokenが▁から始まる場合、スペースを付与する
        if step_result.token.startswith(""):
            word = " " + word

        # 正常な文字が生成できた場合、バッファをクリア
        token_buffer.clear()

        return ChatGenerationChunk(
            message=AIMessageChunk(content=word),
            generation_info=self._convert_step_result_to_dict(step_result),
        )

    def _convert_step_result_to_dict(self, step_result: GenerationStepResult) -> dict:
        """GenerationStepResult型のオブジェクトを辞書型に詰め替えて返す"""
        return {
            "batch_id": step_result.batch_id,
            "is_last": step_result.is_last,
            "log_prob": step_result.log_prob,
            "step": step_result.step,
            "token_id": step_result.token_id,
        }

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            "generator": self.generator,
            "tokenizer": self.tokenizer,
            "max_length": self.max_length,
            "min_length": self.min_length,
            "repetition_penalty": self.repetition_penalty,
            "temperature": self.temperature,
            "sampling_topk": self.topk,
            "sampling_topp": self.topp,
            "no_repeat_ngram_size": self.no_repeat_ngram_size,
            "disable_unk": self.disable_unk,
            "static_prompt": self.static_prompt,
            "cache_static_prompt": self.cache_static_prompt,
        }

動作テストとして、HumanMessageなどChat Modelにおけるメッセージ指定を行います。
モデルはRinna社のrinna/japanese-gpt-neox-3.6b-instruction-ppoを使用。

from chat_models.ctranslate2 import ChatCTranslate2
from langchain.schema import AIMessage, HumanMessage, SystemMessage

chat = ChatCTranslate2(
    generator=generator,
    tokenizer=tokenizer,
    max_length=10,
    prompt_line_separator="<NL>",
    verbose=True,
)

messages = [
    HumanMessage(content="明日の天気は?"),
    AIMessage(content=""),
]
chat(messages)
結果
AIMessage(content='明日の天気は、西から北にかけて晴れ', additional_kwargs={}, example=False)

動いていそうです。
_streamも実装したので、コールバックを使わずにストリーミング表示もしてみます。

for chunk in chat.stream(messages):
    print(chunk)
結果
content='明日' additional_kwargs={} example=False
content='の' additional_kwargs={} example=False
content='天気' additional_kwargs={} example=False
content='は' additional_kwargs={} example=False
content='、' additional_kwargs={} example=False
content='西' additional_kwargs={} example=False
content='から' additional_kwargs={} example=False
content='北' additional_kwargs={} example=False
content='にかけて' additional_kwargs={} example=False
content='晴れ' additional_kwargs={} example=False

問題なさそうですね。

もう少しリッチにしたい場合(llm_outputの内容実装など)は、他のメソッドをオーバーライドする必要があると思いますが、最低限であればこれで使えそうです。

まとめ

ユースケースによりますが、Memoryの利用などChat Modelの方が利用しやすいところもあると思いますのでやってみました。
ただ、上の注意にも書いていますが、公式ドキュメントに記載がないので、あまりよくない実装になっているかもしれません。

langchainの利用は良し悪しありますし、正直ソースまで見ないとよくわからないところもあったりしますが、使いこなせると便利だと思います。

余談

CTranslate2のlangchain LLMが先週リリースされていました。
ただ、コードを見る限り現時点でストリーミング対応してなさそうなので、使用はちょっと見送り。。。

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?