LoginSignup
2
0

LM Format EnforcerでローカルLLMに構造的なテキスト出力をさせる on Databricks

Last updated at Posted at 2024-03-09

Claude3とかプロプライエタリなサービスが盛り上がっている中でローカルLLMの記事を書くという逆張り。。。

導入

不勉強で初めて知ったのですが、LLMに構造的な出力(例:JSONスキーマに則ったフォーマット出力)をさせるためのパッケージとしてLM Format Enforcerというものが出ていました。

LLM Agentを構築するためにはOpenAI Function callingなどを用いて実行ツールの選択と実行を行うと思いますが、ローカルLLMにおいてこれを実行するためにはいくつかの工夫が必要です。

LangChainでは様々なエージェント処理をカバーしており、例えばJSON出力するChatAgentの例が掲載されています。

ただ、性能の低いLLMを使うとそもそもJSON出力が安定せず、エラーハンドリングやプロンプトの工夫などが必要でした。

最近は以下の記事のようにSGLangのJSONデーコード機能を使って出力形式を強制する実装をしていました。

しかし、SGLangというパッケージに推論部含めて縛られること、SGLang自体発展途上のパッケージであるなどの理由から別種の手段が無いかなと考えていました。
Outlinesguidanceの存在は知っていたのですが、ちょっと使いづらいなと・・・)

その中で、LM Format Enforcerは多くのパッケージとの連携など使いやすそうだったので、今回試してみます。

実装・検証はDatabricks on AWSで行いました。
DBRは14.3ML、クラスタタイプをg4dn.xlargeにしています。

Step1. パッケージインストール

ノートブックを作成し、LLM Format Enforcerをインストール。
また、今回はExLlamaV2を使ってローカルLLMの推論を実施させますので、関連パッケージをインストールします。

%pip install -U -qq transformers accelerate "exllamav2>=0.0.15" langchain lm-format-enforcer pydantic

dbutils.library.restartPython()

Step2. ExLlamaV2のChatModelClassを準備

簡易(?)的にExLlamaV2をLangChainのChatModelとして利用するラッパークラスを作成し、exllamav2_chat.pyという名前でノートブックと同じフォルダに保存。
長いので折り畳み。

exllamav2_chat.py
import asyncio

import itertools

from typing import (
    Any,
    List,
    Union,
    Mapping,
    Tuple,
    Optional,
    Iterator,
    AsyncIterator,
)
from langchain_core.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    ChatMessage,
    HumanMessage,
    SystemMessage,
)
from langchain_core.outputs.chat_result import ChatGeneration, ChatResult
from langchain_core.outputs.chat_generation import ChatGenerationChunk

from exllamav2 import (
    ExLlamaV2,
    ExLlamaV2Config,
    ExLlamaV2Cache,
    ExLlamaV2Cache_8bit,
    ExLlamaV2Cache_Q4,
    ExLlamaV2Tokenizer,
)

from exllamav2.generator import (
    ExLlamaV2BaseGenerator,
    ExLlamaV2StreamingGenerator,
    ExLlamaV2Sampler,
)
from exllamav2.generator.filters import ExLlamaV2PrefixFilter

from lmformatenforcer.integrations.exllamav2 import ExLlamaV2TokenEnforcerFilter
from lmformatenforcer import JsonSchemaParser


class ChatExllamaV2Model(BaseChatModel):
    """exllamav2 models.

    To use, you should have the ``exllamav2`` python package installed.

    Example:
        .. code-block:: python

    """

    exllama_config: ExLlamaV2Config
    """ ExLlamaV2 config """
    exllama_model: ExLlamaV2
    """ ExLlamaV2 pretrained Model """
    exllama_tokenizer: ExLlamaV2Tokenizer
    """ ExLlamaV2 Tokenizer """
    exllama_cache: Union[ExLlamaV2Cache, ExLlamaV2Cache_8bit, ExLlamaV2Cache_Q4]
    """ ExLLamaV2 Cache """
    exllama_draft_model: Union[ExLlamaV2, None] = None
    """ ExLlamaV2 draft pretrained Model """
    exllama_draft_cache: Union[
        ExLlamaV2Cache, ExLlamaV2Cache_8bit, ExLlamaV2Cache_Q4, None
    ] = None
    """ ExLlamaV2 draft Cache """

    # メッセージテンプレート
    human_message_template: str = "USER: {}\n"
    ai_message_template: str = "ASSISTANT: {}"
    system_message_template: str = "{}"

    do_sample: bool = True
    max_new_tokens: int = 64
    repetition_penalty: float = 1.1
    frequency_penalty: float = 0.0
    """ Sampler frequency penalty, default = 0.0 (0 to disable) """
    presence_penalty: float = 0.0
    """ Sampler presence penalty, default = 0.0 (0 to disable) """
    temperature: float = 1
    top_k: int = 50
    top_p: float = 0.8
    top_a: float = 0.0
    """ top_p, default = 0.0(0 to disable) """
    skew: float = 0.0
    """ Skey samplig, default = 0.0(0 to disable) """
    typical: float = 0.0
    """ Sampler typical threshold, default = 0.0 (0 to disable) """

    filters: list[Union[ExLlamaV2TokenEnforcerFilter, ExLlamaV2PrefixFilter]] = []
    """ Filters for lm_format_enforcer """
    filter_prefer_eos: bool = False
    """ Filter prefer eos """

    add_bos: bool = True
    """ add bos for batch generate token """

    seed: int = 1234
    """ generator seed """

    prompt_line_separator: str = "\n"

    @classmethod
    def from_model_dir(
        cls,
        model_dir: str,
        draft_model_dir: Union[str, None] = None,
        no_draft_scale: bool = True,
        cache_8bit: bool = False,
        cache_4bit: bool = False,
        cache_max_seq_len: int = -1,
        batch_size: int = 1,
        low_mem=False,
        tokenizer_force_json=False,
        no_flash_attn: bool = True,
        **kwargs: Any,
    ) -> "ChatExllamaV2Model":
        """Construct the exllamav2 model and tokenzier pipeline object from model_id. ExLlamaV2 version is expected >"""

        # Initialize config
        config = ExLlamaV2Config()
        config.model_dir = model_dir
        if low_mem:
            config.set_low_mem()
        config.prepare()

        # Initialize model
        model = ExLlamaV2(config)

        # Initialize tokenizer
        tokenizer = ExLlamaV2Tokenizer(config, force_json=tokenizer_force_json)

        # draft model/cache
        draft_model = None
        draft_cache = None

        if draft_model_dir:

            print(f" -- Draft model: {draft_model_dir}")

            draft_config = ExLlamaV2Config()
            draft_config.model_dir = draft_model_dir
            draft_config.prepare()

            if draft_config.max_seq_len < model.config.max_seq_len:

                if no_draft_scale:
                    print(
                        f" !! Warning: Draft model native max sequence length is less than sequence length for model. Speed may decrease after {draft_config.max_seq_len} tokens."
                    )
                else:
                    ratio = model.config.max_seq_len / draft_config.max_seq_len
                    alpha = -0.13436 + 0.80541 * ratio + 0.28833 * ratio**2
                    draft_config.scale_alpha_value = alpha
                    print(f" -- Applying draft model RoPE alpha = {alpha:.4f}")

            draft_config.max_seq_len = model.config.max_seq_len
            draft_config.no_flash_attn = no_flash_attn

            print(" -- Loading draft model...")

            draft_model = ExLlamaV2(draft_config)
            draft_model.load()

            draft_cache = None
            if cache_8bit:
                draft_cache = ExLlamaV2Cache_8bit(draft_model)
            elif cache_4bit:
                draft_cache = ExLlamaV2Cache_Q4(draft_model)
            else:
                draft_cache = ExLlamaV2Cache(draft_model)

        # cache
        print(" -- Prepare cache...")
        cache = None
        if cache_8bit:
            cache = ExLlamaV2Cache_8bit(
                model,
                lazy=not model.loaded,
                max_seq_len=cache_max_seq_len,
                batch_size=batch_size,
            )
        elif cache_4bit:
            cache = ExLlamaV2Cache_Q4(
                model,
                lazy=not model.loaded,
                max_seq_len=cache_max_seq_len,
                batch_size=batch_size,
            )
        else:
            cache = ExLlamaV2Cache(
                model,
                lazy=not model.loaded,
                max_seq_len=cache_max_seq_len,
                batch_size=batch_size,
            )

        # load model
        if not model.loaded:

            print(" -- Loading model...")
            model.load_autosplit(cache)

        return cls(
            exllama_config=config,
            exllama_model=model,
            exllama_tokenizer=tokenizer,
            exllama_cache=cache,
            exllama_draft_model=draft_model,
            exllama_draft_cache=draft_cache,
            **kwargs,
        )

    @classmethod
    def from_model(
        cls,
        model: "ChatExllamaV2Model",
    ):

        properties = model.__dict__.copy()

        exclude_keys = [
            "name",
            "callbacks",
            "callback_manager",
            "tags",
            "metadata",
            "cache",
            "exllama_config",
            "exllama_model",
            "exllama_tokenizer",
            "exllama_cache",
            "exllama_draft_model",
            "exllama_draft_cache",
        ]

        for k in exclude_keys:
            properties.pop(k)

        return cls(
            exllama_config=model.exllama_config,
            exllama_model=model.exllama_model,
            exllama_tokenizer=model.exllama_tokenizer,
            exllama_cache=model.exllama_cache,
            exllama_draft_model=model.exllama_draft_model,
            exllama_draft_cache=model.exllama_draft_cache,
            **properties,
        )

    def reset_json_schema(self, schema=None):
        if schema is None:  # Reset
            self.filters = []
            self.filter_prefer_eos = False
            return

        schema_parser = JsonSchemaParser(schema)
        lmfe_filter = ExLlamaV2TokenEnforcerFilter(
            schema_parser, self.exllama_tokenizer
        )
        prefix_filter = ExLlamaV2PrefixFilter(
            self.exllama_model, self.exllama_tokenizer, "{"
        )  # Make sure we start JSONing right away

        self.filters = [lmfe_filter, prefix_filter]
        self.filter_prefer_eos = True

    @property
    def _llm_type(self) -> str:
        return "ChatExllamaV2Model"

    def _format_message_as_text(self, message: BaseMessage) -> str:
        if isinstance(message, ChatMessage):
            message_text = f"{self.prompt_line_separator}{message.role.capitalize()}: {message.content}"
        elif isinstance(message, HumanMessage):
            message_text = self.human_message_template.format(message.content)
        elif isinstance(message, AIMessage):
            message_text = self.ai_message_template.format(message.content)
        elif isinstance(message, SystemMessage):
            message_text = self.system_message_template.format(message.content)
        else:
            raise ValueError(f"Got unknown type {message}")
        return message_text

    def _format_messages_as_text(self, messages: List[BaseMessage]) -> str:
        return self.prompt_line_separator.join(
            [self._format_message_as_text(message) for message in messages]
        )

    def _inference_settings(self):
        # Settings
        settings = ExLlamaV2Sampler.Settings()
        settings.temperature = self.temperature
        settings.top_k = self.top_k
        settings.top_p = self.top_p
        settings.top_a = self.top_a
        settings.typical = self.typical
        settings.skew = self.skew
        settings.token_repetition_penalty = self.repetition_penalty
        settings.token_frequency_penalty = self.frequency_penalty
        settings.token_presence_penalty = self.presence_penalty

        settings.filters = self.filters
        settings.filter_prefer_eos = self.filter_prefer_eos

        return settings

    def _generate_base(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ):

        generator = ExLlamaV2BaseGenerator(
            self.exllama_model, self.exllama_cache, self.exllama_tokenizer
        )
        
        return generator

    def _generate_streamer(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> ExLlamaV2StreamingGenerator:

        prompt = self._format_messages_as_text(messages)
        if self.verbose:
            print(prompt)

        _stop = stop or []
        _stop = [_stop] if isinstance(_stop, str) else _stop
        # 単一のトークンIDを持つ(特殊)文字列の場合、トークンIDに置き換える
        def encode_if_str_is_special_token(s):
            token_id = self.exllama_tokenizer.single_id(s)
            return token_id if token_id else s
        
        _stop = [encode_if_str_is_special_token(s) for s in _stop]

        generator = ExLlamaV2StreamingGenerator(
            model=self.exllama_model,
            cache=self.exllama_cache,
            tokenizer=self.exllama_tokenizer,
            draft_model=self.exllama_draft_model,
            draft_cache=self.exllama_draft_cache,
            num_speculative_tokens=5,
        )

        settings = self._inference_settings()

        # Prompt
        # add_bos/add_eosの指定はプロンプトの種類に応じて考慮が必要かも。
        input_ids = self.exllama_tokenizer.encode(
            prompt, add_bos=True, add_eos=False, encode_special_tokens=True
        )

        prompt_tokens = input_ids.shape[-1]

        generator.set_stop_conditions(_stop + [self.exllama_tokenizer.eos_token_id])
        generator.begin_stream_ex(input_ids, settings)

        return generator

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:

        prompt = self._format_messages_as_text(messages)
        if self.verbose:
            print(prompt)

        # stop tokenは未指定の場合、tokenizerのeos_token_idを利用。文字列を指定した場合、その1番目のTokenIDを使用。
        _stop_token = -1
        if stop and isinstance(stop, str):
            _stop_token = self.exllama_tokenizer.encode_special(stop)[0]
            if self.verbose:
                print("Stop token id:", _stop_token)
        elif stop and isinstance(stop, list):
            _stop_token = self.exllama_tokenizer.encode_special(stop[0])[0]
            if self.verbose:
                print("Stop token id:", _stop_token)


        generator = self._generate_base(messages, stop, **kwargs)

        settings = self._inference_settings()
        output = generator.generate_simple(
            prompt,
            settings,
            self.max_new_tokens,
            seed=self.seed,
            add_bos=self.add_bos,
            # encode_special_tokens=False,
            # decode_special_tokens=False,
            stop_token=_stop_token,
            completion_only=True,
        )

        chat_generation = ChatGeneration(message=AIMessage(content=output))
        generated_tokens = generator.tokenizer.num_tokens(output)
        return ChatResult(
            generations=[chat_generation],
            llm_output={"completion_tokens": generated_tokens},
        )

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:

        streamer = self._generate_streamer(messages, stop, **kwargs)
        generated_tokens = 0

        while True:
            res = streamer.stream_ex()
            generated_tokens += 1
            yield ChatGenerationChunk(message=AIMessageChunk(content=res["chunk"]))
            if run_manager:
                run_manager.on_llm_new_token(
                    res["chunk"],
                    verbose=self.verbose,
                )
            if res["eos"] or generated_tokens == self.max_new_tokens:
                break

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Union[List[str], None] = None,
        run_manager: Union[AsyncCallbackManagerForLLMRun, None] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:

        streamer = self._generate_streamer(messages, stop, **kwargs)
        generated_tokens = 0

        while True:
            res = streamer.stream_ex()
            generated_tokens += 1
            yield ChatGenerationChunk(message=AIMessageChunk(content=res["chunk"]))
            if run_manager:
                await run_manager.on_llm_new_token(
                    res["chunk"],
                    verbose=self.verbose,
                )
            await asyncio.sleep(0)
            if res["eos"] or generated_tokens == self.max_new_tokens:
                break

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            "config": self.exllama_config,
            "model": self.exllama_model,
            "tokenizer": self.exllama_tokenizer,
        }

このクラスで重要なのはreset_json_schemaメソッドです。
LM Format EnforcerはExLlamaV2との統合モジュールを提供しており、ExLlamaV2TokenEnforcerFilterクラスというフィルタを使うことでJSONスキーマ等を指定することで出力フォーマットを強制することができます。

from exllamav2.generator.filters import ExLlamaV2PrefixFilter
from lmformatenforcer.integrations.exllamav2 import ExLlamaV2TokenEnforcerFilter
from lmformatenforcer import JsonSchemaParser

def reset_json_schema(self, schema=None):
    if schema is None:  # Reset
        self.filters = []
        self.filter_prefer_eos = False
        return

    schema_parser = JsonSchemaParser(schema)
    lmfe_filter = ExLlamaV2TokenEnforcerFilter(
        schema_parser, self.exllama_tokenizer
    )
    prefix_filter = ExLlamaV2PrefixFilter(
        self.exllama_model, self.exllama_tokenizer, "{"
    )  # Make sure we start JSONing right away

    self.filters = [lmfe_filter, prefix_filter]
    self.filter_prefer_eos = True

あとは、作成したフィルタをExLlamaV2のSamplerに設定することで適用できます。

# Settings
settings = ExLlamaV2Sampler.Settings()
settings.temperature = self.temperature
# 中略
settings.filters = self.filters
settings.filter_prefer_eos = self.filter_prefer_eos

Step3. モデルのロード

ExLlamaV2でモデルをロードします。
今回は事前ダウンロードしておいた以下のモデルを利用しました。

from exllamav2_chat import ChatExllamaV2Model

model_path = "/Volumes/training/llm/model_snapshots/models--TheBloke--openchat-3.5-0106-GPTQ/"

chat_model = ChatExllamaV2Model.from_model_dir(
    model_path,
    cache_max_seq_len=4096,
    system_message_template="{}<|end_of_turn|>",    
    human_message_template="GPT4 Correct User: {}<|end_of_turn|>GPT4 Correct Assistant: ",
    ai_message_template="GPT4 Correct Assistant: {}",
    temperature=0,
    top_p=0.0001,
    max_new_tokens=512,
    repetition_penalty = 1.15,
    low_memory=True,
    cache_4bit=True,
)

# OpenChat-3.5-0106のtokenizerファイルのバグ対応
chat_model.exllama_tokenizer.eos_token_id = 32000

Step4. 推論でJSON出力を確認してみる

では、実際に構造化出力の推論をしてみましょう。
ExLlamaV2のリポジトリ内にサンプルがあったので、そのお題を流用し、Superheroというクラスを用いて、そのスキーマを使った出力を強制させてみます。

from pydantic import BaseModel, conlist
from typing import Literal
from langchain_core.messages import HumanMessage
from langchain_core.output_parsers.json import JsonOutputParser


class SuperheroAppearance(BaseModel):
    title: str
    issue_number: int
    year: int

class Superhero(BaseModel):
    name: str
    secret_identity: str
    superpowers: conlist(str, max_length=5)
    first_appearance: SuperheroAppearance
    gender: Literal["male", "female"]


chat_model.reset_json_schema(Superhero.schema())

prompt = "Here is some information about Superman:\n"

chain = chat_model | JsonOutputParser()
chain.invoke([HumanMessage(content=prompt)])
出力
{'name': 'Superman',
 'superpowers': ['invulnerability',
  'flight',
  'super strength',
  'freeze breath'],
 'secret_identity': 'Clark Kent',
 'first_appearance': {'year': 1938,
  'issue_number': 28,
  'title': 'Action Comics'},
 'gender': 'male'}

JSON形式でsupermanの情報が出力できました。

Step5. ツールを選択させてみる

エージェント処理でよく使う、指定されたツール(機能)の中から適したものの選択をさせてみます。

ダミーで2種類のRetrieverツールを作成。

from langchain.tools.retriever import create_retriever_tool

tool1 = create_retriever_tool(
    None,
    "retriever_about_freren",
    "Search and return information about 葬送のフリーレン.Tool input needs search text.",
)

tool2 = create_retriever_tool(
    None,
    "retriever_about_spark",
    "Search and return information about Apache Spark.Tool input needs search text.",
)

tools = [tool1, tool2]

指定したツールのリストから、ツールを選択してJSON形式で結果を返すプロンプトテンプレートを用意。

from langchain_core.prompts import PromptTemplate

def create_function_calling_prompt(tools):
    """ Tool選択のためのプロンプトテンプレートを生成 """

    tool_names = ", ".join([t.name for t in tools])
    tools_desc = "\n".join([t.name + ": " + t.description for t in tools])

    prompt = PromptTemplate(
        template=("TOOLS\n------\n"
                  "Assistant can ask the user to use tools to look up information that may be helpful in answering the users original question." 
                  "The tools the human can use are:\n\n"
                  "{tools_desc}\n\n"
                  "RESPONSE FORMAT INSTRUCTIONS\n----------------------------\n\n"
                  "When responding to me, please output a response in one of two formats:\n\n"
                  "Markdown code snippet formatted in the following schema:\n\n"
                  "```json\n"
                  '{{\n'
                  '    "action": string, \\\\ The action to take. Must be one of tools: {tool_names}\n'
                  '    "action_input": string \\\\ The input to the action\n'
                  '}}\n'
                  '```\n\n'
                  "USER\'S INPUT\n--------------------\n"
                  "Here is the user\'s input (remember to respond with a markdown code snippet of a json blob with a single action, and NOTHING else):\n\n"
                  "{input}"),
        input_variables=["input"],
        partial_variables={"tool_names":tool_names, "tools_desc":tools_desc},
    )

    return prompt

出力用のスキーマを定義し、ツール選択のChainを作成して推論を実行。

from pydantic import BaseModel, conlist
from typing import Literal
from langchain_core.output_parsers.json import JsonOutputParser

tool_names = [t.name for t in tools]

class Action(BaseModel):
    action: Literal[tuple(tool_names)]
    action_input: str

chat_model.reset_json_schema(Action.schema())

prompt = create_function_calling_prompt(tools)
chain = prompt | chat_model | JsonOutputParser()

print(chain.invoke({"input": "フリーレンの声優は誰?"}))
print(chain.invoke({"input": "Sparkって何"}))
出力
{'action': 'retriever_about_freren', 'action_input': 'フリーレンの声優は誰?'}
{'action': 'retriever_about_spark', 'action_input': 'Apache Spark'}

それぞれ適切なツール選択を行い、かつ確実にJSON形式で返ってきています。
ちなみにLM Format Enforcerを使わない場合も、今回のプロンプト+LLMの組み合わせだと高確率でJSON形式の回答が返ってくるのですが、そうならないときもありました。(今回の試験範囲ではLM Format Enforcerは確実にJSON形式で返ってきた)

補足

この記事を準備中に気づいたのですが、構造化された情報の取得についてはLangChainの公式ドキュメント内にも存在していました。(読んでなかった。。。)

また、同じくLangChainで構造化出力に関する機能アップデートも進んでいるようです。
ローカルLLMのカバーがされるかわからないのですが、精読しておこうと思います。

さらに、LM Format Enforcerについてはcommunityパッケージの方でサポート実装がされていますね。

その他、transformersやvLLMなどといった主要なパッケージをLM Format Enforcerはカバーしています。

まとめ

LM Format Enforcerを使って構造化出力を行ってみました。

エージェント処理だけに留まらず、LLMを活用するためには構造化されたインプットとアウトプットは実アプリを考えると様々なユースケースで利用されると思います。

わたしはここCompound AI systemのような構造を考えることが今後必要だと認識しており、その中でもモデル出力の構造化は必須だろうと考えています。
LLM自体も非常に面白いのですが、周辺技術についてもより身に着けていこうと思っています。

2
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
0