はじめに
langchainにおける抽象化の雰囲気をざっくり理解するため、langchainのLLMChain.run
を実行したときに、元のOpenAI APIがどのように呼び出されているのか、処理の流れを追ってみました。
動機
langchainへの苦手意識を減らすためです。
langchainは複雑な処理を短いコードで書くことができて便利なのですが、処理が何重にも抽象化されているのでチュートリアルにちょっと手を加えて改良するのが大変そうな印象があります。
そこで、langchainの基本的な機能であるLLMChainの内部で素のOpenAI APIライブラリがどのように使われているのか、処理の流れを追うことで、langchainへの苦手意識を減らしたいです。
(なおLLMChainはAzure OpenAI Serviceなど他の言語モデルでも使えるのですが、今回はOpenAI APIに絞って処理を追います。)
例えば以下のコードでLLMChain.runが実行できます。
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
llm = ChatOpenAI(temperature=0.9)
prompt = ChatPromptTemplate.from_messages(
("human", "What is a good name for a company that makes {product}?")
)
from langchain.chains import LLMChain
chain = LLMChain(llm=llm, prompt=prompt)
# Run the chain only specifying the input variable.
print(chain.run("colorful socks"))
Socktastic
一方、類似の処理をopenaiの公式ライブラリで実行すると以下のようになります。
import openai
import dotenv
import os
dotenv.load_dotenv()
openai.api_key = os.getenv("OPENAI_API_KEY")
response = openai.ChatCompletion.create(
model="gpt-3.5-turbo",
messages = [{"role": "user", "content": "What is a good name for a company that makes colorful socks?"}]
)
print(response["choices"][0]["message"]["content"])
ColorfulStep
LLMChain.run
を実行した場合も、最終的には、openai.ChatCompletion.create
が実行されているはずですので、そこに至るまでの、処理の流れを追いたいです。
処理の全体感
langchain
はOpenAI APIを始めとするLLMのラッパーライブラリです。LLMの実行や関係する処理をchain
という単位で記述し、chain同士をつなげることで、より複雑な処理を実現します。
LLMChain
はlangchainの基本的なchainの一つです。LLMChainに任意のLLMやプロンプトテンプレートを与えて初期化し、LLMChain.run
にプロンプトテンプレートに代入する文字列を与えて実行することで、LLMの生成結果を戻り値として取得できます。
LLMChainに与えるLLMもlangchainのクラスとして抽象化されたクラスを使います。OpenAI API(gpt-3.5-turbo)を使用する場合は、OpenAI APIを抽象化したChatOpenAI
というクラスを初期化して、LLMChainに与えます。
ChatOpenAI
はBaseChatModel
という対話用のLLMを抽象化されたクラスを継承しており、BaseChatModel
はさらにBaseLanguageModel
というクラスを継承しています。その先はpydanticのBaseModel
というクラスになるので、langchainの実装という意味では、BaseLanguageModel
まで追えばだいたい完了です。
方針
langchainのソースコードを読んでいきます。バージョンはv0.0.268です。
基本的には、抽象度が深くなる方向、つまり、LLMChain、ChatOpenAI、BaseChatModel、BaseLanguageModelの順に処理を追えばよいです。
ただし、ときどき、継承関係にあるクラスにおいて、具体的な処理が継承元ではなく継承先に書かれていることがあるので、その場合は、抽象度が浅い方向に戻ったりします。
またLLMChainの継承元であるChain
クラスも必要に応じてみます。
LLMChain
それではLLMChainから見ます。run
というメソッドは存在しないので、おそらく継承元のChain
クラスで定義されているはずです。
Chain
Chainを読みます。
def run(self, *args: str, **kwargs: str) -> str:
"""Run the chain as text in, text out or multiple variables, text out."""
if len(self.output_keys) != 1:
raise ValueError(
f"`run` not supported when there is not exactly "
f"one output key. Got {self.output_keys}."
)
if args and not kwargs:
if len(args) != 1:
raise ValueError("`run` supports only one positional argument.")
return self(args[0])[self.output_keys[0]]
if kwargs and not args:
return self(kwargs)[self.output_keys[0]]
raise ValueError(
f"`run` supported with either positional arguments or keyword arguments"
f" but not both. Got args: {args} and kwargs: {kwargs}."
)
return self(kwargs)[self.output_keys[0]]
の部分で、selfが関数のように扱われています。
self
つまりChain
自身を関数として扱えるのは、Chain.__call__
メソッドが定義されているためです。
Chain.__call__
を見ます。
def __call__(
self, inputs: Union[Dict[str, Any], Any], return_only_outputs: bool = False
) -> Dict[str, Any]:
"""Run the logic of this chain and add to output if desired.
Args:
inputs: Dictionary of inputs, or single input if chain expects
only one param.
return_only_outputs: boolean for whether to return only outputs in the
response. If True, only new keys generated by this chain will be
returned. If False, both input keys and new keys generated by this
chain will be returned. Defaults to False.
"""
inputs = self.prep_inputs(inputs)
self.callback_manager.on_chain_start(
{"name": self.__class__.__name__},
inputs,
verbose=self.verbose,
)
try:
outputs = self._call(inputs)
except (KeyboardInterrupt, Exception) as e:
self.callback_manager.on_chain_error(e, verbose=self.verbose)
raise e
self.callback_manager.on_chain_end(outputs, verbose=self.verbose)
return self.prep_outputs(inputs, outputs, return_only_outputs)
outputs = self._call(inputs)
で出力を得ています。
Chain._call
を見ます。
@abstractmethod
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
"""Run the logic of this chain and return the output."""
Chain._call
は抽象メソッドであり、具体的な処理が書かれていません。
抽象メソッドは継承先、今回の場合はLLMChain
で処理が定義されます。
LLMChainに戻ります。
LLMChain(2回目)
LLMChain._call
を見ます。
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
response = self.generate([inputs], run_manager=run_manager)
return self.create_outputs(response)[0]
response = self.generate([inputs], run_manager=run_manager)
からLLMChain.generate
が実行されていることがわかります。
LLMChain.generate
を見ます。
def generate(
self,
input_list: List[Dict[str, Any]],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> LLMResult:
"""Generate LLM result from inputs."""
prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)
return self.llm.generate_prompt(
prompts,
stop,
callbacks=run_manager.get_child() if run_manager else None,
**self.llm_kwargs,
)
self.llm.generate_prompt
の実行結果をreturnしています。
self.llm
はLLMChain
の宣言時に代入されたLLMの抽象化クラスです。
クラスの冒頭で以下のように宣言されており、BaseLanguageModelクラスであることがわかります。
llm: BaseLanguageModel
BaseLanguageModel
BaseLanguageModelのBaseLanguageModel.generate_prompt
を見ます。
@abstractmethod
def generate_prompt(
self,
prompts: List[PromptValue],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
"""Pass a sequence of prompts to the model and return model generations.
This method should make use of batched calls for models that expose a batched
API.
Use this method when you want to:
1. take advantage of batched calls,
2. need more output from the model than just the top generated value,
3. are building chains that are agnostic to the underlying language model
type (e.g., pure text completion models vs chat models).
Args:
prompts: List of PromptValues. A PromptValue is an object that can be
converted to match the format of any language model (string for pure
text generation models and BaseMessages for chat models).
stop: Stop words to use when generating. Model output is cut off at the
first occurrence of any of these substrings.
callbacks: Callbacks to pass through. Used for executing additional
functionality, such as logging or streaming, throughout generation.
**kwargs: Arbitrary additional keyword arguments. These are usually passed
to the model provider API call.
Returns:
An LLMResult, which contains a list of candidate Generations for each input
prompt and additional model provider-specific output.
"""
抽象化メソッドのため、具体的な処理は継承先に記述されていると期待されます。
BaseChatModel
そこで継承先のBaseChatModelを見ます。
BaseChatModel.generate_prompt
を見ます。
def generate_prompt(
self,
prompts: List[PromptValue],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
prompt_messages = [p.to_messages() for p in prompts]
return self.generate(prompt_messages, stop=stop, callbacks=callbacks, **kwargs)
self.generate
の実行結果を返しています。
BaseChatModel.generate
を見ます。
def generate(
self,
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> LLMResult:
"""Top Level call"""
params = self._get_invocation_params(stop=stop, **kwargs)
options = {"stop": stop}
callback_manager = CallbackManager.configure(
callbacks,
self.callbacks,
self.verbose,
tags,
self.tags,
metadata,
self.metadata,
)
run_managers = callback_manager.on_chat_model_start(
dumpd(self), messages, invocation_params=params, options=options
)
results = []
for i, m in enumerate(messages):
try:
results.append(
self._generate_with_cache(
m,
stop=stop,
run_manager=run_managers[i] if run_managers else None,
**kwargs,
)
)
except (KeyboardInterrupt, Exception) as e:
if run_managers:
run_managers[i].on_llm_error(e)
raise e
flattened_outputs = [
LLMResult(generations=[res.generations], llm_output=res.llm_output)
for res in results
]
llm_output = self._combine_llm_outputs([res.llm_output for res in results])
generations = [res.generations for res in results]
output = LLMResult(generations=generations, llm_output=llm_output)
if run_managers:
run_infos = []
for manager, flattened_output in zip(run_managers, flattened_outputs):
manager.on_llm_end(flattened_output)
run_infos.append(RunInfo(run_id=manager.run_id))
output.run = run_infos
return output
プロンプトの本体であるmessages
がどこに渡されているか探します。
以下でself._generate_with_cache
メソッドに渡されています。
for i, m in enumerate(messages):
try:
results.append(
self._generate_with_cache(
m,
stop=stop,
run_manager=run_managers[i] if run_managers else None,
**kwargs,
)
)```
`BaseChatModel._generate_with_cache`を見ます。
```Python
def _generate_with_cache(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
new_arg_supported = inspect.signature(self._generate).parameters.get(
"run_manager"
)
disregard_cache = self.cache is not None and not self.cache
if langchain.llm_cache is None or disregard_cache:
# This happens when langchain.cache is None, but self.cache is True
if self.cache is not None and self.cache:
raise ValueError(
"Asked to cache, but no cache found at `langchain.cache`."
)
if new_arg_supported:
return self._generate(
messages, stop=stop, run_manager=run_manager, **kwargs
)
else:
return self._generate(messages, stop=stop, **kwargs)
else:
llm_string = self._get_llm_string(stop=stop, **kwargs)
prompt = dumps(messages)
cache_val = langchain.llm_cache.lookup(prompt, llm_string)
if isinstance(cache_val, list):
return ChatResult(generations=cache_val)
else:
if new_arg_supported:
result = self._generate(
messages, stop=stop, run_manager=run_manager, **kwargs
)
else:
result = self._generate(messages, stop=stop, **kwargs)
langchain.llm_cache.update(prompt, llm_string, result.generations)
return result
self._generate
の実行結果がreturnされています。
return self._generate(
messages, stop=stop, run_manager=run_manager, **kwargs
)
BaseChatModel._generate
を見ます。
@abstractmethod
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
"""Top Level call"""
抽象メソッドです。
ChatOpenAI
継承先であるChatOpenAIを見ます。
ChatOpenAI._generate
を見ます。
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
stream: Optional[bool] = None,
**kwargs: Any,
) -> ChatResult:
if stream if stream is not None else self.streaming:
generation: Optional[ChatGenerationChunk] = None
for chunk in self._stream(
messages=messages, stop=stop, run_manager=run_manager, **kwargs
):
if generation is None:
generation = chunk
else:
generation += chunk
assert generation is not None
return ChatResult(generations=[generation])
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
response = self.completion_with_retry(
messages=message_dicts, run_manager=run_manager, **params
)
return self._create_chat_result(response)
一旦ストリーミングでない分岐に注目すると、以下の部分で、self.completion_with_retry
が実行され、整形後、returnされています。
response = self.completion_with_retry(
messages=message_dicts, run_manager=run_manager, **params
)
return self._create_chat_result(response)
ChatOpenAI.completion_with_retry
を見ます。
def completion_with_retry(
self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
) -> Any:
"""Use tenacity to retry the completion call."""
retry_decorator = _create_retry_decorator(self, run_manager=run_manager)
@retry_decorator
def _completion_with_retry(**kwargs: Any) -> Any:
return self.client.create(**kwargs)
return _completion_with_retry(**kwargs)
self.client.create(**kwargs)
が呼ばれています。
create
はOpenAI APIのpythonライブラリであるopenai.ChatCompletion.create
と思われます。
したがって、self.client
にopenai.ChatCompletion
が代入されていると予想できます。
しかしChatOpenAI.client
は冒頭で初期値Noneで宣言されています。
class ChatOpenAI(BaseChatModel):
"""略"""
client: Any = None #: :meta private:
他の箇所を探すと、validate_environment
メソッドの内部でopenaiライブラリ関連の設定が代入されていることがわかります。
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["openai_api_key"] = get_from_dict_or_env(
values, "openai_api_key", "OPENAI_API_KEY"
)
values["openai_organization"] = get_from_dict_or_env(
values,
"openai_organization",
"OPENAI_ORGANIZATION",
default="",
)
values["openai_api_base"] = get_from_dict_or_env(
values,
"openai_api_base",
"OPENAI_API_BASE",
default="",
)
values["openai_proxy"] = get_from_dict_or_env(
values,
"openai_proxy",
"OPENAI_PROXY",
default="",
)
try:
import openai
except ImportError:
raise ValueError(
"Could not import openai python package. "
"Please install it with `pip install openai`."
)
try:
values["client"] = openai.ChatCompletion
except AttributeError:
raise ValueError(
"`openai` has no `ChatCompletion` attribute, this is likely "
"due to an old version of the openai package. Try upgrading it "
"with `pip install --upgrade openai`."
)
if values["n"] < 1:
raise ValueError("n must be at least 1.")
if values["n"] > 1 and values["streaming"]:
raise ValueError("n must be 1 when streaming.")
return values
@root_validator()
はpydanticのデコレータで、初期化時に実施する独自のフォーマットチェックを定義するためのものです。詳しくないですが、values[変数名]で、クラス変数に値を代入できるようです。
以下の処理で、client
にopenai.ChatCompletionが代入されていることが確認できました。
try:
values["client"] = openai.ChatCompletion
これでようやく、LLMChain.run
から始まり、openai.ChatCompletion.create
が実行されるまでの処理を追うことができました。
おわりに
ざっくりソースコードを読む流れをおさらいすると以下のようになります。
- LLMChain.run (見つからない)
- Chain.run
- Chain.__call__
- Chain._call
- LLMChain._call
- LLMChain._generate
- BaseLanguageModel.generate_prompt
- BaseChatModel.generate_prompt
- BaseChatModel._generate_with_cache
- BaseChatModel._generate
- ChatOpenAI._generate
- ChatOpenAI.client.create
- ChatOpenAI.validate_environment
実はこれは工程で言うと全体の半分で、本当は、このあと、ChatOpenAI.client.create
の戻り値がどのように加工されて、最終的にLLMChain.run
でreturnされるのかの処理も追うべきなのですが、力尽きたので、一旦ここまでとします。今後の課題とします。
とりあえずここまででも謎めいていたchainの処理に少し親しみが湧いたので良かったです。