Claude3とかプロプライエタリなサービスが盛り上がっている中でローカルLLMの記事を書くという逆張り。。。
導入
不勉強で初めて知ったのですが、LLMに構造的な出力(例:JSONスキーマに則ったフォーマット出力)をさせるためのパッケージとしてLM Format Enforcerというものが出ていました。
LLM Agentを構築するためにはOpenAI Function callingなどを用いて実行ツールの選択と実行を行うと思いますが、ローカルLLMにおいてこれを実行するためにはいくつかの工夫が必要です。
LangChainでは様々なエージェント処理をカバーしており、例えばJSON出力するChatAgentの例が掲載されています。
ただ、性能の低いLLMを使うとそもそもJSON出力が安定せず、エラーハンドリングやプロンプトの工夫などが必要でした。
最近は以下の記事のようにSGLangのJSONデーコード機能を使って出力形式を強制する実装をしていました。
しかし、SGLangというパッケージに推論部含めて縛られること、SGLang自体発展途上のパッケージであるなどの理由から別種の手段が無いかなと考えていました。
(Outlinesやguidanceの存在は知っていたのですが、ちょっと使いづらいなと・・・)
その中で、LM Format Enforcerは多くのパッケージとの連携など使いやすそうだったので、今回試してみます。
実装・検証はDatabricks on AWSで行いました。
DBRは14.3ML、クラスタタイプをg4dn.xlargeにしています。
Step1. パッケージインストール
ノートブックを作成し、LLM Format Enforcerをインストール。
また、今回はExLlamaV2を使ってローカルLLMの推論を実施させますので、関連パッケージをインストールします。
%pip install -U -qq transformers accelerate "exllamav2>=0.0.15" langchain lm-format-enforcer pydantic
dbutils.library.restartPython()
Step2. ExLlamaV2のChatModelClassを準備
簡易(?)的にExLlamaV2をLangChainのChatModelとして利用するラッパークラスを作成し、exllamav2_chat.pyという名前でノートブックと同じフォルダに保存。
長いので折り畳み。
exllamav2_chat.py
import asyncio
import itertools
from typing import (
Any,
List,
Union,
Mapping,
Tuple,
Optional,
Iterator,
AsyncIterator,
)
from langchain_core.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
ChatMessage,
HumanMessage,
SystemMessage,
)
from langchain_core.outputs.chat_result import ChatGeneration, ChatResult
from langchain_core.outputs.chat_generation import ChatGenerationChunk
from exllamav2 import (
ExLlamaV2,
ExLlamaV2Config,
ExLlamaV2Cache,
ExLlamaV2Cache_8bit,
ExLlamaV2Cache_Q4,
ExLlamaV2Tokenizer,
)
from exllamav2.generator import (
ExLlamaV2BaseGenerator,
ExLlamaV2StreamingGenerator,
ExLlamaV2Sampler,
)
from exllamav2.generator.filters import ExLlamaV2PrefixFilter
from lmformatenforcer.integrations.exllamav2 import ExLlamaV2TokenEnforcerFilter
from lmformatenforcer import JsonSchemaParser
class ChatExllamaV2Model(BaseChatModel):
"""exllamav2 models.
To use, you should have the ``exllamav2`` python package installed.
Example:
.. code-block:: python
"""
exllama_config: ExLlamaV2Config
""" ExLlamaV2 config """
exllama_model: ExLlamaV2
""" ExLlamaV2 pretrained Model """
exllama_tokenizer: ExLlamaV2Tokenizer
""" ExLlamaV2 Tokenizer """
exllama_cache: Union[ExLlamaV2Cache, ExLlamaV2Cache_8bit, ExLlamaV2Cache_Q4]
""" ExLLamaV2 Cache """
exllama_draft_model: Union[ExLlamaV2, None] = None
""" ExLlamaV2 draft pretrained Model """
exllama_draft_cache: Union[
ExLlamaV2Cache, ExLlamaV2Cache_8bit, ExLlamaV2Cache_Q4, None
] = None
""" ExLlamaV2 draft Cache """
# メッセージテンプレート
human_message_template: str = "USER: {}\n"
ai_message_template: str = "ASSISTANT: {}"
system_message_template: str = "{}"
do_sample: bool = True
max_new_tokens: int = 64
repetition_penalty: float = 1.1
frequency_penalty: float = 0.0
""" Sampler frequency penalty, default = 0.0 (0 to disable) """
presence_penalty: float = 0.0
""" Sampler presence penalty, default = 0.0 (0 to disable) """
temperature: float = 1
top_k: int = 50
top_p: float = 0.8
top_a: float = 0.0
""" top_p, default = 0.0(0 to disable) """
skew: float = 0.0
""" Skey samplig, default = 0.0(0 to disable) """
typical: float = 0.0
""" Sampler typical threshold, default = 0.0 (0 to disable) """
filters: list[Union[ExLlamaV2TokenEnforcerFilter, ExLlamaV2PrefixFilter]] = []
""" Filters for lm_format_enforcer """
filter_prefer_eos: bool = False
""" Filter prefer eos """
add_bos: bool = True
""" add bos for batch generate token """
seed: int = 1234
""" generator seed """
prompt_line_separator: str = "\n"
@classmethod
def from_model_dir(
cls,
model_dir: str,
draft_model_dir: Union[str, None] = None,
no_draft_scale: bool = True,
cache_8bit: bool = False,
cache_4bit: bool = False,
cache_max_seq_len: int = -1,
batch_size: int = 1,
low_mem=False,
tokenizer_force_json=False,
no_flash_attn: bool = True,
**kwargs: Any,
) -> "ChatExllamaV2Model":
"""Construct the exllamav2 model and tokenzier pipeline object from model_id. ExLlamaV2 version is expected >"""
# Initialize config
config = ExLlamaV2Config()
config.model_dir = model_dir
if low_mem:
config.set_low_mem()
config.prepare()
# Initialize model
model = ExLlamaV2(config)
# Initialize tokenizer
tokenizer = ExLlamaV2Tokenizer(config, force_json=tokenizer_force_json)
# draft model/cache
draft_model = None
draft_cache = None
if draft_model_dir:
print(f" -- Draft model: {draft_model_dir}")
draft_config = ExLlamaV2Config()
draft_config.model_dir = draft_model_dir
draft_config.prepare()
if draft_config.max_seq_len < model.config.max_seq_len:
if no_draft_scale:
print(
f" !! Warning: Draft model native max sequence length is less than sequence length for model. Speed may decrease after {draft_config.max_seq_len} tokens."
)
else:
ratio = model.config.max_seq_len / draft_config.max_seq_len
alpha = -0.13436 + 0.80541 * ratio + 0.28833 * ratio**2
draft_config.scale_alpha_value = alpha
print(f" -- Applying draft model RoPE alpha = {alpha:.4f}")
draft_config.max_seq_len = model.config.max_seq_len
draft_config.no_flash_attn = no_flash_attn
print(" -- Loading draft model...")
draft_model = ExLlamaV2(draft_config)
draft_model.load()
draft_cache = None
if cache_8bit:
draft_cache = ExLlamaV2Cache_8bit(draft_model)
elif cache_4bit:
draft_cache = ExLlamaV2Cache_Q4(draft_model)
else:
draft_cache = ExLlamaV2Cache(draft_model)
# cache
print(" -- Prepare cache...")
cache = None
if cache_8bit:
cache = ExLlamaV2Cache_8bit(
model,
lazy=not model.loaded,
max_seq_len=cache_max_seq_len,
batch_size=batch_size,
)
elif cache_4bit:
cache = ExLlamaV2Cache_Q4(
model,
lazy=not model.loaded,
max_seq_len=cache_max_seq_len,
batch_size=batch_size,
)
else:
cache = ExLlamaV2Cache(
model,
lazy=not model.loaded,
max_seq_len=cache_max_seq_len,
batch_size=batch_size,
)
# load model
if not model.loaded:
print(" -- Loading model...")
model.load_autosplit(cache)
return cls(
exllama_config=config,
exllama_model=model,
exllama_tokenizer=tokenizer,
exllama_cache=cache,
exllama_draft_model=draft_model,
exllama_draft_cache=draft_cache,
**kwargs,
)
@classmethod
def from_model(
cls,
model: "ChatExllamaV2Model",
):
properties = model.__dict__.copy()
exclude_keys = [
"name",
"callbacks",
"callback_manager",
"tags",
"metadata",
"cache",
"exllama_config",
"exllama_model",
"exllama_tokenizer",
"exllama_cache",
"exllama_draft_model",
"exllama_draft_cache",
]
for k in exclude_keys:
properties.pop(k)
return cls(
exllama_config=model.exllama_config,
exllama_model=model.exllama_model,
exllama_tokenizer=model.exllama_tokenizer,
exllama_cache=model.exllama_cache,
exllama_draft_model=model.exllama_draft_model,
exllama_draft_cache=model.exllama_draft_cache,
**properties,
)
def reset_json_schema(self, schema=None):
if schema is None: # Reset
self.filters = []
self.filter_prefer_eos = False
return
schema_parser = JsonSchemaParser(schema)
lmfe_filter = ExLlamaV2TokenEnforcerFilter(
schema_parser, self.exllama_tokenizer
)
prefix_filter = ExLlamaV2PrefixFilter(
self.exllama_model, self.exllama_tokenizer, "{"
) # Make sure we start JSONing right away
self.filters = [lmfe_filter, prefix_filter]
self.filter_prefer_eos = True
@property
def _llm_type(self) -> str:
return "ChatExllamaV2Model"
def _format_message_as_text(self, message: BaseMessage) -> str:
if isinstance(message, ChatMessage):
message_text = f"{self.prompt_line_separator}{message.role.capitalize()}: {message.content}"
elif isinstance(message, HumanMessage):
message_text = self.human_message_template.format(message.content)
elif isinstance(message, AIMessage):
message_text = self.ai_message_template.format(message.content)
elif isinstance(message, SystemMessage):
message_text = self.system_message_template.format(message.content)
else:
raise ValueError(f"Got unknown type {message}")
return message_text
def _format_messages_as_text(self, messages: List[BaseMessage]) -> str:
return self.prompt_line_separator.join(
[self._format_message_as_text(message) for message in messages]
)
def _inference_settings(self):
# Settings
settings = ExLlamaV2Sampler.Settings()
settings.temperature = self.temperature
settings.top_k = self.top_k
settings.top_p = self.top_p
settings.top_a = self.top_a
settings.typical = self.typical
settings.skew = self.skew
settings.token_repetition_penalty = self.repetition_penalty
settings.token_frequency_penalty = self.frequency_penalty
settings.token_presence_penalty = self.presence_penalty
settings.filters = self.filters
settings.filter_prefer_eos = self.filter_prefer_eos
return settings
def _generate_base(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
**kwargs: Any,
):
generator = ExLlamaV2BaseGenerator(
self.exllama_model, self.exllama_cache, self.exllama_tokenizer
)
return generator
def _generate_streamer(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> ExLlamaV2StreamingGenerator:
prompt = self._format_messages_as_text(messages)
if self.verbose:
print(prompt)
_stop = stop or []
_stop = [_stop] if isinstance(_stop, str) else _stop
# 単一のトークンIDを持つ(特殊)文字列の場合、トークンIDに置き換える
def encode_if_str_is_special_token(s):
token_id = self.exllama_tokenizer.single_id(s)
return token_id if token_id else s
_stop = [encode_if_str_is_special_token(s) for s in _stop]
generator = ExLlamaV2StreamingGenerator(
model=self.exllama_model,
cache=self.exllama_cache,
tokenizer=self.exllama_tokenizer,
draft_model=self.exllama_draft_model,
draft_cache=self.exllama_draft_cache,
num_speculative_tokens=5,
)
settings = self._inference_settings()
# Prompt
# add_bos/add_eosの指定はプロンプトの種類に応じて考慮が必要かも。
input_ids = self.exllama_tokenizer.encode(
prompt, add_bos=True, add_eos=False, encode_special_tokens=True
)
prompt_tokens = input_ids.shape[-1]
generator.set_stop_conditions(_stop + [self.exllama_tokenizer.eos_token_id])
generator.begin_stream_ex(input_ids, settings)
return generator
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
prompt = self._format_messages_as_text(messages)
if self.verbose:
print(prompt)
# stop tokenは未指定の場合、tokenizerのeos_token_idを利用。文字列を指定した場合、その1番目のTokenIDを使用。
_stop_token = -1
if stop and isinstance(stop, str):
_stop_token = self.exllama_tokenizer.encode_special(stop)[0]
if self.verbose:
print("Stop token id:", _stop_token)
elif stop and isinstance(stop, list):
_stop_token = self.exllama_tokenizer.encode_special(stop[0])[0]
if self.verbose:
print("Stop token id:", _stop_token)
generator = self._generate_base(messages, stop, **kwargs)
settings = self._inference_settings()
output = generator.generate_simple(
prompt,
settings,
self.max_new_tokens,
seed=self.seed,
add_bos=self.add_bos,
# encode_special_tokens=False,
# decode_special_tokens=False,
stop_token=_stop_token,
completion_only=True,
)
chat_generation = ChatGeneration(message=AIMessage(content=output))
generated_tokens = generator.tokenizer.num_tokens(output)
return ChatResult(
generations=[chat_generation],
llm_output={"completion_tokens": generated_tokens},
)
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
streamer = self._generate_streamer(messages, stop, **kwargs)
generated_tokens = 0
while True:
res = streamer.stream_ex()
generated_tokens += 1
yield ChatGenerationChunk(message=AIMessageChunk(content=res["chunk"]))
if run_manager:
run_manager.on_llm_new_token(
res["chunk"],
verbose=self.verbose,
)
if res["eos"] or generated_tokens == self.max_new_tokens:
break
async def _astream(
self,
messages: List[BaseMessage],
stop: Union[List[str], None] = None,
run_manager: Union[AsyncCallbackManagerForLLMRun, None] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
streamer = self._generate_streamer(messages, stop, **kwargs)
generated_tokens = 0
while True:
res = streamer.stream_ex()
generated_tokens += 1
yield ChatGenerationChunk(message=AIMessageChunk(content=res["chunk"]))
if run_manager:
await run_manager.on_llm_new_token(
res["chunk"],
verbose=self.verbose,
)
await asyncio.sleep(0)
if res["eos"] or generated_tokens == self.max_new_tokens:
break
@property
def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters."""
return {
"config": self.exllama_config,
"model": self.exllama_model,
"tokenizer": self.exllama_tokenizer,
}
このクラスで重要なのはreset_json_schema
メソッドです。
LM Format EnforcerはExLlamaV2との統合モジュールを提供しており、ExLlamaV2TokenEnforcerFilter
クラスというフィルタを使うことでJSONスキーマ等を指定することで出力フォーマットを強制することができます。
from exllamav2.generator.filters import ExLlamaV2PrefixFilter
from lmformatenforcer.integrations.exllamav2 import ExLlamaV2TokenEnforcerFilter
from lmformatenforcer import JsonSchemaParser
def reset_json_schema(self, schema=None):
if schema is None: # Reset
self.filters = []
self.filter_prefer_eos = False
return
schema_parser = JsonSchemaParser(schema)
lmfe_filter = ExLlamaV2TokenEnforcerFilter(
schema_parser, self.exllama_tokenizer
)
prefix_filter = ExLlamaV2PrefixFilter(
self.exllama_model, self.exllama_tokenizer, "{"
) # Make sure we start JSONing right away
self.filters = [lmfe_filter, prefix_filter]
self.filter_prefer_eos = True
あとは、作成したフィルタをExLlamaV2のSamplerに設定することで適用できます。
# Settings
settings = ExLlamaV2Sampler.Settings()
settings.temperature = self.temperature
# 中略
settings.filters = self.filters
settings.filter_prefer_eos = self.filter_prefer_eos
Step3. モデルのロード
ExLlamaV2でモデルをロードします。
今回は事前ダウンロードしておいた以下のモデルを利用しました。
from exllamav2_chat import ChatExllamaV2Model
model_path = "/Volumes/training/llm/model_snapshots/models--TheBloke--openchat-3.5-0106-GPTQ/"
chat_model = ChatExllamaV2Model.from_model_dir(
model_path,
cache_max_seq_len=4096,
system_message_template="{}<|end_of_turn|>",
human_message_template="GPT4 Correct User: {}<|end_of_turn|>GPT4 Correct Assistant: ",
ai_message_template="GPT4 Correct Assistant: {}",
temperature=0,
top_p=0.0001,
max_new_tokens=512,
repetition_penalty = 1.15,
low_memory=True,
cache_4bit=True,
)
# OpenChat-3.5-0106のtokenizerファイルのバグ対応
chat_model.exllama_tokenizer.eos_token_id = 32000
Step4. 推論でJSON出力を確認してみる
では、実際に構造化出力の推論をしてみましょう。
ExLlamaV2のリポジトリ内にサンプルがあったので、そのお題を流用し、Superhero
というクラスを用いて、そのスキーマを使った出力を強制させてみます。
from pydantic import BaseModel, conlist
from typing import Literal
from langchain_core.messages import HumanMessage
from langchain_core.output_parsers.json import JsonOutputParser
class SuperheroAppearance(BaseModel):
title: str
issue_number: int
year: int
class Superhero(BaseModel):
name: str
secret_identity: str
superpowers: conlist(str, max_length=5)
first_appearance: SuperheroAppearance
gender: Literal["male", "female"]
chat_model.reset_json_schema(Superhero.schema())
prompt = "Here is some information about Superman:\n"
chain = chat_model | JsonOutputParser()
chain.invoke([HumanMessage(content=prompt)])
{'name': 'Superman',
'superpowers': ['invulnerability',
'flight',
'super strength',
'freeze breath'],
'secret_identity': 'Clark Kent',
'first_appearance': {'year': 1938,
'issue_number': 28,
'title': 'Action Comics'},
'gender': 'male'}
JSON形式でsupermanの情報が出力できました。
Step5. ツールを選択させてみる
エージェント処理でよく使う、指定されたツール(機能)の中から適したものの選択をさせてみます。
ダミーで2種類のRetrieverツールを作成。
from langchain.tools.retriever import create_retriever_tool
tool1 = create_retriever_tool(
None,
"retriever_about_freren",
"Search and return information about 葬送のフリーレン.Tool input needs search text.",
)
tool2 = create_retriever_tool(
None,
"retriever_about_spark",
"Search and return information about Apache Spark.Tool input needs search text.",
)
tools = [tool1, tool2]
指定したツールのリストから、ツールを選択してJSON形式で結果を返すプロンプトテンプレートを用意。
from langchain_core.prompts import PromptTemplate
def create_function_calling_prompt(tools):
""" Tool選択のためのプロンプトテンプレートを生成 """
tool_names = ", ".join([t.name for t in tools])
tools_desc = "\n".join([t.name + ": " + t.description for t in tools])
prompt = PromptTemplate(
template=("TOOLS\n------\n"
"Assistant can ask the user to use tools to look up information that may be helpful in answering the users original question."
"The tools the human can use are:\n\n"
"{tools_desc}\n\n"
"RESPONSE FORMAT INSTRUCTIONS\n----------------------------\n\n"
"When responding to me, please output a response in one of two formats:\n\n"
"Markdown code snippet formatted in the following schema:\n\n"
"```json\n"
'{{\n'
' "action": string, \\\\ The action to take. Must be one of tools: {tool_names}\n'
' "action_input": string \\\\ The input to the action\n'
'}}\n'
'```\n\n'
"USER\'S INPUT\n--------------------\n"
"Here is the user\'s input (remember to respond with a markdown code snippet of a json blob with a single action, and NOTHING else):\n\n"
"{input}"),
input_variables=["input"],
partial_variables={"tool_names":tool_names, "tools_desc":tools_desc},
)
return prompt
出力用のスキーマを定義し、ツール選択のChainを作成して推論を実行。
from pydantic import BaseModel, conlist
from typing import Literal
from langchain_core.output_parsers.json import JsonOutputParser
tool_names = [t.name for t in tools]
class Action(BaseModel):
action: Literal[tuple(tool_names)]
action_input: str
chat_model.reset_json_schema(Action.schema())
prompt = create_function_calling_prompt(tools)
chain = prompt | chat_model | JsonOutputParser()
print(chain.invoke({"input": "フリーレンの声優は誰?"}))
print(chain.invoke({"input": "Sparkって何"}))
{'action': 'retriever_about_freren', 'action_input': 'フリーレンの声優は誰?'}
{'action': 'retriever_about_spark', 'action_input': 'Apache Spark'}
それぞれ適切なツール選択を行い、かつ確実にJSON形式で返ってきています。
ちなみにLM Format Enforcerを使わない場合も、今回のプロンプト+LLMの組み合わせだと高確率でJSON形式の回答が返ってくるのですが、そうならないときもありました。(今回の試験範囲ではLM Format Enforcerは確実にJSON形式で返ってきた)
補足
この記事を準備中に気づいたのですが、構造化された情報の取得についてはLangChainの公式ドキュメント内にも存在していました。(読んでなかった。。。)
また、同じくLangChainで構造化出力に関する機能アップデートも進んでいるようです。
ローカルLLMのカバーがされるかわからないのですが、精読しておこうと思います。
さらに、LM Format Enforcerについてはcommunityパッケージの方でサポート実装がされていますね。
その他、transformersやvLLMなどといった主要なパッケージをLM Format Enforcerはカバーしています。
まとめ
LM Format Enforcerを使って構造化出力を行ってみました。
エージェント処理だけに留まらず、LLMを活用するためには構造化されたインプットとアウトプットは実アプリを考えると様々なユースケースで利用されると思います。
わたしはここCompound AI systemのような構造を考えることが今後必要だと認識しており、その中でもモデル出力の構造化は必須だろうと考えています。
LLM自体も非常に面白いのですが、周辺技術についてもより身に着けていこうと思っています。