7
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

【langchain】LLMChainでOpenAI APIが呼ばれるまでの処理の流れを理解する

Last updated at Posted at 2023-08-19

はじめに

langchainにおける抽象化の雰囲気をざっくり理解するため、langchainのLLMChain.runを実行したときに、元のOpenAI APIがどのように呼び出されているのか、処理の流れを追ってみました。

動機

langchainへの苦手意識を減らすためです。
langchainは複雑な処理を短いコードで書くことができて便利なのですが、処理が何重にも抽象化されているのでチュートリアルにちょっと手を加えて改良するのが大変そうな印象があります。
そこで、langchainの基本的な機能であるLLMChainの内部で素のOpenAI APIライブラリがどのように使われているのか、処理の流れを追うことで、langchainへの苦手意識を減らしたいです。
(なおLLMChainはAzure OpenAI Serviceなど他の言語モデルでも使えるのですが、今回はOpenAI APIに絞って処理を追います。)

例えば以下のコードでLLMChain.runが実行できます。

from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate

llm = ChatOpenAI(temperature=0.9)
prompt = ChatPromptTemplate.from_messages(
    ("human", "What is a good name for a company that makes {product}?")
)

from langchain.chains import LLMChain
chain = LLMChain(llm=llm, prompt=prompt)

# Run the chain only specifying the input variable.
print(chain.run("colorful socks"))
Socktastic

一方、類似の処理をopenaiの公式ライブラリで実行すると以下のようになります。

import openai
import dotenv
import os
dotenv.load_dotenv()
openai.api_key = os.getenv("OPENAI_API_KEY")

response = openai.ChatCompletion.create(
  model="gpt-3.5-turbo",
  messages = [{"role": "user", "content": "What is a good name for a company that makes colorful socks?"}]
)

print(response["choices"][0]["message"]["content"])
ColorfulStep

LLMChain.runを実行した場合も、最終的には、openai.ChatCompletion.createが実行されているはずですので、そこに至るまでの、処理の流れを追いたいです。

処理の全体感

langchainはOpenAI APIを始めとするLLMのラッパーライブラリです。LLMの実行や関係する処理をchainという単位で記述し、chain同士をつなげることで、より複雑な処理を実現します。

LLMChainはlangchainの基本的なchainの一つです。LLMChainに任意のLLMやプロンプトテンプレートを与えて初期化し、LLMChain.runにプロンプトテンプレートに代入する文字列を与えて実行することで、LLMの生成結果を戻り値として取得できます。

LLMChainに与えるLLMもlangchainのクラスとして抽象化されたクラスを使います。OpenAI API(gpt-3.5-turbo)を使用する場合は、OpenAI APIを抽象化したChatOpenAIというクラスを初期化して、LLMChainに与えます。

ChatOpenAIBaseChatModelという対話用のLLMを抽象化されたクラスを継承しており、BaseChatModelはさらにBaseLanguageModelというクラスを継承しています。その先はpydanticのBaseModelというクラスになるので、langchainの実装という意味では、BaseLanguageModelまで追えばだいたい完了です。

方針

langchainのソースコードを読んでいきます。バージョンはv0.0.268です。

基本的には、抽象度が深くなる方向、つまり、LLMChain、ChatOpenAI、BaseChatModel、BaseLanguageModelの順に処理を追えばよいです。
ただし、ときどき、継承関係にあるクラスにおいて、具体的な処理が継承元ではなく継承先に書かれていることがあるので、その場合は、抽象度が浅い方向に戻ったりします。
またLLMChainの継承元であるChainクラスも必要に応じてみます。

LLMChain

それではLLMChainから見ます。runというメソッドは存在しないので、おそらく継承元のChainクラスで定義されているはずです。

Chain

Chainを読みます。

    def run(self, *args: str, **kwargs: str) -> str:
        """Run the chain as text in, text out or multiple variables, text out."""
        if len(self.output_keys) != 1:
            raise ValueError(
                f"`run` not supported when there is not exactly "
                f"one output key. Got {self.output_keys}."
            )

        if args and not kwargs:
            if len(args) != 1:
                raise ValueError("`run` supports only one positional argument.")
            return self(args[0])[self.output_keys[0]]

        if kwargs and not args:
            return self(kwargs)[self.output_keys[0]]

        raise ValueError(
            f"`run` supported with either positional arguments or keyword arguments"
            f" but not both. Got args: {args} and kwargs: {kwargs}."
        )

return self(kwargs)[self.output_keys[0]]の部分で、selfが関数のように扱われています。
selfつまりChain自身を関数として扱えるのは、Chain.__call__メソッドが定義されているためです。
Chain.__call__を見ます。

    def __call__(
        self, inputs: Union[Dict[str, Any], Any], return_only_outputs: bool = False
    ) -> Dict[str, Any]:
        """Run the logic of this chain and add to output if desired.

        Args:
            inputs: Dictionary of inputs, or single input if chain expects
                only one param.
            return_only_outputs: boolean for whether to return only outputs in the
                response. If True, only new keys generated by this chain will be
                returned. If False, both input keys and new keys generated by this
                chain will be returned. Defaults to False.

        """
        inputs = self.prep_inputs(inputs)
        self.callback_manager.on_chain_start(
            {"name": self.__class__.__name__},
            inputs,
            verbose=self.verbose,
        )
        try:
            outputs = self._call(inputs)
        except (KeyboardInterrupt, Exception) as e:
            self.callback_manager.on_chain_error(e, verbose=self.verbose)
            raise e
        self.callback_manager.on_chain_end(outputs, verbose=self.verbose)
        return self.prep_outputs(inputs, outputs, return_only_outputs)

outputs = self._call(inputs)で出力を得ています。
Chain._callを見ます。

    @abstractmethod
    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
        """Run the logic of this chain and return the output."""

Chain._callは抽象メソッドであり、具体的な処理が書かれていません。
抽象メソッドは継承先、今回の場合はLLMChainで処理が定義されます。

LLMChainに戻ります。

LLMChain(2回目)

LLMChain._callを見ます。

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        response = self.generate([inputs], run_manager=run_manager)
        return self.create_outputs(response)[0]

response = self.generate([inputs], run_manager=run_manager)からLLMChain.generateが実行されていることがわかります。

LLMChain.generateを見ます。

    def generate(
        self,
        input_list: List[Dict[str, Any]],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> LLMResult:
        """Generate LLM result from inputs."""
        prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)
        return self.llm.generate_prompt(
            prompts,
            stop,
            callbacks=run_manager.get_child() if run_manager else None,
            **self.llm_kwargs,
        )

self.llm.generate_promptの実行結果をreturnしています。
self.llmLLMChainの宣言時に代入されたLLMの抽象化クラスです。
クラスの冒頭で以下のように宣言されており、BaseLanguageModelクラスであることがわかります。

    llm: BaseLanguageModel

BaseLanguageModel

BaseLanguageModelBaseLanguageModel.generate_promptを見ます。

    @abstractmethod
    def generate_prompt(
        self,
        prompts: List[PromptValue],
        stop: Optional[List[str]] = None,
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Pass a sequence of prompts to the model and return model generations.

        This method should make use of batched calls for models that expose a batched
        API.

        Use this method when you want to:
            1. take advantage of batched calls,
            2. need more output from the model than just the top generated value,
            3. are building chains that are agnostic to the underlying language model
                type (e.g., pure text completion models vs chat models).

        Args:
            prompts: List of PromptValues. A PromptValue is an object that can be
                converted to match the format of any language model (string for pure
                text generation models and BaseMessages for chat models).
            stop: Stop words to use when generating. Model output is cut off at the
                first occurrence of any of these substrings.
            callbacks: Callbacks to pass through. Used for executing additional
                functionality, such as logging or streaming, throughout generation.
            **kwargs: Arbitrary additional keyword arguments. These are usually passed
                to the model provider API call.

        Returns:
            An LLMResult, which contains a list of candidate Generations for each input
                prompt and additional model provider-specific output.
        """

抽象化メソッドのため、具体的な処理は継承先に記述されていると期待されます。

BaseChatModel

そこで継承先のBaseChatModelを見ます。

BaseChatModel.generate_promptを見ます。

    def generate_prompt(
        self,
        prompts: List[PromptValue],
        stop: Optional[List[str]] = None,
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> LLMResult:
        prompt_messages = [p.to_messages() for p in prompts]
        return self.generate(prompt_messages, stop=stop, callbacks=callbacks, **kwargs)

self.generateの実行結果を返しています。

BaseChatModel.generateを見ます。

    def generate(
        self,
        messages: List[List[BaseMessage]],
        stop: Optional[List[str]] = None,
        callbacks: Callbacks = None,
        *,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Top Level call"""
        params = self._get_invocation_params(stop=stop, **kwargs)
        options = {"stop": stop}

        callback_manager = CallbackManager.configure(
            callbacks,
            self.callbacks,
            self.verbose,
            tags,
            self.tags,
            metadata,
            self.metadata,
        )
        run_managers = callback_manager.on_chat_model_start(
            dumpd(self), messages, invocation_params=params, options=options
        )
        results = []
        for i, m in enumerate(messages):
            try:
                results.append(
                    self._generate_with_cache(
                        m,
                        stop=stop,
                        run_manager=run_managers[i] if run_managers else None,
                        **kwargs,
                    )
                )
            except (KeyboardInterrupt, Exception) as e:
                if run_managers:
                    run_managers[i].on_llm_error(e)
                raise e
        flattened_outputs = [
            LLMResult(generations=[res.generations], llm_output=res.llm_output)
            for res in results
        ]
        llm_output = self._combine_llm_outputs([res.llm_output for res in results])
        generations = [res.generations for res in results]
        output = LLMResult(generations=generations, llm_output=llm_output)
        if run_managers:
            run_infos = []
            for manager, flattened_output in zip(run_managers, flattened_outputs):
                manager.on_llm_end(flattened_output)
                run_infos.append(RunInfo(run_id=manager.run_id))
            output.run = run_infos
        return output

プロンプトの本体であるmessagesがどこに渡されているか探します。
以下でself._generate_with_cacheメソッドに渡されています。

        for i, m in enumerate(messages):
            try:
                results.append(
                    self._generate_with_cache(
                        m,
                        stop=stop,
                        run_manager=run_managers[i] if run_managers else None,
                        **kwargs,
                    )
                )```

`BaseChatModel._generate_with_cache`を見ます

```Python
    def _generate_with_cache(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        new_arg_supported = inspect.signature(self._generate).parameters.get(
            "run_manager"
        )
        disregard_cache = self.cache is not None and not self.cache
        if langchain.llm_cache is None or disregard_cache:
            # This happens when langchain.cache is None, but self.cache is True
            if self.cache is not None and self.cache:
                raise ValueError(
                    "Asked to cache, but no cache found at `langchain.cache`."
                )
            if new_arg_supported:
                return self._generate(
                    messages, stop=stop, run_manager=run_manager, **kwargs
                )
            else:
                return self._generate(messages, stop=stop, **kwargs)
        else:
            llm_string = self._get_llm_string(stop=stop, **kwargs)
            prompt = dumps(messages)
            cache_val = langchain.llm_cache.lookup(prompt, llm_string)
            if isinstance(cache_val, list):
                return ChatResult(generations=cache_val)
            else:
                if new_arg_supported:
                    result = self._generate(
                        messages, stop=stop, run_manager=run_manager, **kwargs
                    )
                else:
                    result = self._generate(messages, stop=stop, **kwargs)
                langchain.llm_cache.update(prompt, llm_string, result.generations)
                return result

self._generateの実行結果がreturnされています。

                return self._generate(
                    messages, stop=stop, run_manager=run_manager, **kwargs
                )

BaseChatModel._generateを見ます。

    @abstractmethod
    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Top Level call"""

抽象メソッドです。

ChatOpenAI

継承先であるChatOpenAIを見ます。

ChatOpenAI._generateを見ます。

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        stream: Optional[bool] = None,
        **kwargs: Any,
    ) -> ChatResult:
        if stream if stream is not None else self.streaming:
            generation: Optional[ChatGenerationChunk] = None
            for chunk in self._stream(
                messages=messages, stop=stop, run_manager=run_manager, **kwargs
            ):
                if generation is None:
                    generation = chunk
                else:
                    generation += chunk
            assert generation is not None
            return ChatResult(generations=[generation])

        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs}
        response = self.completion_with_retry(
            messages=message_dicts, run_manager=run_manager, **params
        )
        return self._create_chat_result(response)

一旦ストリーミングでない分岐に注目すると、以下の部分で、self.completion_with_retryが実行され、整形後、returnされています。

        response = self.completion_with_retry(
            messages=message_dicts, run_manager=run_manager, **params
        )
        return self._create_chat_result(response)

ChatOpenAI.completion_with_retryを見ます。

    def completion_with_retry(
        self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
    ) -> Any:
        """Use tenacity to retry the completion call."""
        retry_decorator = _create_retry_decorator(self, run_manager=run_manager)

        @retry_decorator
        def _completion_with_retry(**kwargs: Any) -> Any:
            return self.client.create(**kwargs)

        return _completion_with_retry(**kwargs)

self.client.create(**kwargs)が呼ばれています。
createはOpenAI APIのpythonライブラリであるopenai.ChatCompletion.createと思われます。
したがって、self.clientopenai.ChatCompletionが代入されていると予想できます。

しかしChatOpenAI.clientは冒頭で初期値Noneで宣言されています。

class ChatOpenAI(BaseChatModel):
    """略"""
    client: Any = None  #: :meta private:

他の箇所を探すと、validate_environmentメソッドの内部でopenaiライブラリ関連の設定が代入されていることがわかります。

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        values["openai_api_key"] = get_from_dict_or_env(
            values, "openai_api_key", "OPENAI_API_KEY"
        )
        values["openai_organization"] = get_from_dict_or_env(
            values,
            "openai_organization",
            "OPENAI_ORGANIZATION",
            default="",
        )
        values["openai_api_base"] = get_from_dict_or_env(
            values,
            "openai_api_base",
            "OPENAI_API_BASE",
            default="",
        )
        values["openai_proxy"] = get_from_dict_or_env(
            values,
            "openai_proxy",
            "OPENAI_PROXY",
            default="",
        )
        try:
            import openai

        except ImportError:
            raise ValueError(
                "Could not import openai python package. "
                "Please install it with `pip install openai`."
            )
        try:
            values["client"] = openai.ChatCompletion
        except AttributeError:
            raise ValueError(
                "`openai` has no `ChatCompletion` attribute, this is likely "
                "due to an old version of the openai package. Try upgrading it "
                "with `pip install --upgrade openai`."
            )
        if values["n"] < 1:
            raise ValueError("n must be at least 1.")
        if values["n"] > 1 and values["streaming"]:
            raise ValueError("n must be 1 when streaming.")
        return values

@root_validator()はpydanticのデコレータで、初期化時に実施する独自のフォーマットチェックを定義するためのものです。詳しくないですが、values[変数名]で、クラス変数に値を代入できるようです。
以下の処理で、clientにopenai.ChatCompletionが代入されていることが確認できました。

        try:
            values["client"] = openai.ChatCompletion

これでようやく、LLMChain.runから始まり、openai.ChatCompletion.createが実行されるまでの処理を追うことができました。

おわりに

ざっくりソースコードを読む流れをおさらいすると以下のようになります。

  • LLMChain.run (見つからない)
  • Chain.run
  • Chain.__call__
  • Chain._call
  • LLMChain._call
  • LLMChain._generate
  • BaseLanguageModel.generate_prompt
  • BaseChatModel.generate_prompt
  • BaseChatModel._generate_with_cache
  • BaseChatModel._generate
  • ChatOpenAI._generate
  • ChatOpenAI.client.create
  • ChatOpenAI.validate_environment

実はこれは工程で言うと全体の半分で、本当は、このあと、ChatOpenAI.client.createの戻り値がどのように加工されて、最終的にLLMChain.runでreturnされるのかの処理も追うべきなのですが、力尽きたので、一旦ここまでとします。今後の課題とします。

とりあえずここまででも謎めいていたchainの処理に少し親しみが湧いたので良かったです。

7
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
7
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?