先に結論
-
langchain.chat_models.base.BaseChatModel
を継承したクラスを作成して、最低限_generate
を実装すれば動く。 - もう少しリッチにする場合は、
_agenerate
、_stream
、_astream
も実装するとよい。
導入
チャットボットのようなアプリをlangchainで作る場合、LLMsよりもChat Modelsのほうが何かと使い勝手がいい(気がする)。
とはいえ、Huggingfaceのモデルなど、ローカルLLMをChat Modelとして動かす仕組みは今のところ無いので、サクッとカスタムモデルを用意しようと考えたのですが、カスタムLLMのドキュメントはあっても、カスタムChatModelのドキュメントがない。
※ だいぶ前のIssueに質問が上がっていますが、望むような回答が付いていませんでした。
というわけで、github内のコードを見ながらやりかたを模索・実装してみます。
注意
完全自己流なので、参考にする場合は自己責任でお願いします。
方法の模索
Chat Modelのベースクラスが以下のソース内で定義されているので、このあたり読むと何となくわかります。
ローカルLLM対応にあたっては、以下が参考になりました。
他、同じフォルダ内のfake.py
やhuman.py
などのソースコードも有用でした。
結論としては冒頭の通り、
-
langchain.chat_models.base.BaseChatModel
を継承したクラスを作成して、最低限_generate
を実装すれば動く模様。(langchain.chat_models.base.SimpleChatModelでもOK) -
_agenerate
、_stream
、_astream
も標準実装が無いため、実装するといいと思います。
実装
というわけで試験実装やってみました。
題材として、こちらで作ったCTranslate2用のカスタムLLMをChat Modelとして作り直してみます。
コードは以下。長いので折り畳み。
試験実装
from ctranslate2 import Generator, GenerationStepResult
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from typing import (
Any,
List,
Union,
Dict,
Mapping,
Optional,
Iterable,
Iterator,
AsyncIterator,
Callable,
)
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.chat_models.base import BaseChatModel
from langchain.pydantic_v1 import Field, root_validator
from langchain.schema.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
ChatMessage,
HumanMessage,
SystemMessage,
)
from langchain.schema import (
ChatGeneration,
ChatResult,
)
from langchain.schema.output import ChatGenerationChunk
def _default_format_message_as_text(message: BaseMessage) -> str:
if isinstance(message, ChatMessage):
message_text = f"\n\n{message.role.capitalize()}: {message.content}"
elif isinstance(message, HumanMessage):
message_text = f"ユーザ: {message.content}"
elif isinstance(message, AIMessage):
message_text = f"システム: {message.content}"
elif isinstance(message, SystemMessage):
message_text = f"{message.content}"
else:
raise ValueError(f"Got unknown type {message}")
return message_text
class ChatCTranslate2(BaseChatModel):
generator: Generator
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
format_message_as_text_func: Callable = Field(
default_factory=lambda: _default_format_message_as_text
)
max_length: int = 256
min_length: int = 0
repetition_penalty: float = 1.1
temperature: float = 1
topk: int = 1
topp: float = 1
no_repeat_ngram_size: int = 0
disable_unk: bool = False
static_prompt: Optional[List[str]] = None
cache_static_prompt: bool = True
prompt_line_separator: str = "\n"
@property
def _llm_type(self) -> str:
return "CTranslate2"
def _format_messages_as_text(self, messages: List[BaseMessage]) -> str:
return self.prompt_line_separator.join(
[self.format_message_as_text_func(message) for message in messages]
)
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
prompt = self._format_messages_as_text(messages)
step_results = self._generate_tokens(prompt)
output_ids = []
token_buffer = []
for step_result in step_results:
output_ids.append(step_result.token_id)
if run_manager:
if chunk := self._decode_with_buffer(step_result, token_buffer):
run_manager.on_llm_new_token(
chunk.text,
verbose=self.verbose,
logprobs=step_result.log_prob if step_result.log_prob else None,
)
text = self.tokenizer.decode(
output_ids,
skip_special_tokens=True,
)
chat_generation = ChatGeneration(message=AIMessage(content=text))
return ChatResult(
generations=[chat_generation],
llm_output={"completion_tokens": len(output_ids)},
)
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
prompt = self._format_messages_as_text(messages)
step_results = self._generate_tokens(prompt)
output_ids = []
token_buffer = []
for step_result in step_results:
output_ids.append(step_result.token_id)
if run_manager:
if chunk := self._decode_with_buffer(step_result, token_buffer):
await run_manager.on_llm_new_token(
chunk.text,
verbose=self.verbose,
logprobs=step_result.log_prob if step_result.log_prob else None,
)
text = self.tokenizer.decode(
output_ids,
skip_special_tokens=True,
)
chat_generation = ChatGeneration(message=AIMessage(content=text))
return ChatResult(
generations=[chat_generation],
llm_output={"completion_tokens": len(output_ids)},
)
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
prompt = self._format_messages_as_text(messages)
step_results = self._generate_tokens(prompt)
token_buffer = []
for step_result in step_results:
if chunk := self._decode_with_buffer(step_result, token_buffer):
yield chunk
if run_manager:
run_manager.on_llm_new_token(
chunk.text,
verbose=self.verbose,
logprobs=step_result.log_prob if step_result.log_prob else None,
)
async def _astream(
self,
messages: List[BaseMessage],
stop: Union[List[str], None] = None,
run_manager: Union[AsyncCallbackManagerForLLMRun, None] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
prompt = self._format_messages_as_text(messages)
step_results = self._generate_tokens(prompt)
token_buffer = []
for step_result in step_results:
if chunk := self._decode_with_buffer(step_result, token_buffer):
yield chunk
if run_manager:
await run_manager.on_llm_new_token(
chunk.text,
verbose=self.verbose,
logprobs=step_result.log_prob if step_result.log_prob else None,
)
def _generate_tokens(
self,
prompt: str,
) -> Iterable[GenerationStepResult]:
"""
Generates a tuple of tokens from the input text.
Removes special characters and punctuation, and splits the text on whitespace.
Args:
- text (str): The text to generate tokens from.
Returns:
- Tuple[str]: A tuple of tokens generated from the input text.
"""
# 推論の実行
tokens = self.tokenizer.convert_ids_to_tokens(
self.tokenizer.encode(prompt, add_special_tokens=False)
)
step_results = self.generator.generate_tokens(
tokens,
max_length=self.max_length,
min_length=self.min_length,
sampling_topk=self.topk,
sampling_topp=self.topp,
sampling_temperature=self.temperature,
repetition_penalty=self.repetition_penalty,
return_log_prob=True,
no_repeat_ngram_size=self.no_repeat_ngram_size,
disable_unk=self.disable_unk,
static_prompt=self.static_prompt,
cache_static_prompt=self.cache_static_prompt,
)
return step_results
def _decode_with_buffer(
self, step_result: GenerationStepResult, token_buffer: list
) -> Union[ChatGenerationChunk, None]:
"""
Decodes the token buffer with the provided step_result,
returning a ChatGenerationChunk if a valid word is decoded.
Args:
- step_result (GenerationStepResult): The step result to decode the token buffer with.
- token_buffer (list): The token buffer to decode.
Returns:
- Union[ChatGenerationChunk, None]: Returns a ChatGenerationChunk containing the content of the decoded buffer,
along with the generation_info of step_result if a valid word is decoded, otherwise None.
"""
token_buffer.append(step_result.token_id)
word = self.tokenizer.decode(
token_buffer,
skip_special_tokens=True,
)
# 全て変換不能文字の場合、終了
if all(c == "�" for c in word):
return None
# step_resultのtokenが▁から始まる場合、スペースを付与する
if step_result.token.startswith("▁"):
word = " " + word
# 正常な文字が生成できた場合、バッファをクリア
token_buffer.clear()
return ChatGenerationChunk(
message=AIMessageChunk(content=word),
generation_info=self._convert_step_result_to_dict(step_result),
)
def _convert_step_result_to_dict(self, step_result: GenerationStepResult) -> dict:
"""GenerationStepResult型のオブジェクトを辞書型に詰め替えて返す"""
return {
"batch_id": step_result.batch_id,
"is_last": step_result.is_last,
"log_prob": step_result.log_prob,
"step": step_result.step,
"token_id": step_result.token_id,
}
@property
def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters."""
return {
"generator": self.generator,
"tokenizer": self.tokenizer,
"max_length": self.max_length,
"min_length": self.min_length,
"repetition_penalty": self.repetition_penalty,
"temperature": self.temperature,
"sampling_topk": self.topk,
"sampling_topp": self.topp,
"no_repeat_ngram_size": self.no_repeat_ngram_size,
"disable_unk": self.disable_unk,
"static_prompt": self.static_prompt,
"cache_static_prompt": self.cache_static_prompt,
}
動作テストとして、HumanMessageなどChat Modelにおけるメッセージ指定を行います。
モデルはRinna社のrinna/japanese-gpt-neox-3.6b-instruction-ppo
を使用。
from chat_models.ctranslate2 import ChatCTranslate2
from langchain.schema import AIMessage, HumanMessage, SystemMessage
chat = ChatCTranslate2(
generator=generator,
tokenizer=tokenizer,
max_length=10,
prompt_line_separator="<NL>",
verbose=True,
)
messages = [
HumanMessage(content="明日の天気は?"),
AIMessage(content=""),
]
chat(messages)
AIMessage(content='明日の天気は、西から北にかけて晴れ', additional_kwargs={}, example=False)
動いていそうです。
_stream
も実装したので、コールバックを使わずにストリーミング表示もしてみます。
for chunk in chat.stream(messages):
print(chunk)
content='明日' additional_kwargs={} example=False
content='の' additional_kwargs={} example=False
content='天気' additional_kwargs={} example=False
content='は' additional_kwargs={} example=False
content='、' additional_kwargs={} example=False
content='西' additional_kwargs={} example=False
content='から' additional_kwargs={} example=False
content='北' additional_kwargs={} example=False
content='にかけて' additional_kwargs={} example=False
content='晴れ' additional_kwargs={} example=False
問題なさそうですね。
もう少しリッチにしたい場合(llm_outputの内容実装など)は、他のメソッドをオーバーライドする必要があると思いますが、最低限であればこれで使えそうです。
まとめ
ユースケースによりますが、Memoryの利用などChat Modelの方が利用しやすいところもあると思いますのでやってみました。
ただ、上の注意にも書いていますが、公式ドキュメントに記載がないので、あまりよくない実装になっているかもしれません。
langchainの利用は良し悪しありますし、正直ソースまで見ないとよくわからないところもあったりしますが、使いこなせると便利だと思います。
余談
CTranslate2のlangchain LLMが先週リリースされていました。
ただ、コードを見る限り現時点でストリーミング対応してなさそうなので、使用はちょっと見送り。。。