はじめに
以下の記事で、LLMChainでOpenAI APIがどのように呼び出されているかを確認しました。
上記の記事では、関数の呼び出しの流れを確認しただけで、入出力がどのように加工されているかは追えていませんでした。
そこで以下の記事で、ChatOpenAIからOpenAI APIを呼び出すまでに、どのように入出力が加工されているのかを勉強しました。
前回記事で追えていなかった、LLMChainからChatOpenAIの呼び出しまでの入出力の加工を理解することが、今回の焦点です。
具体的には、LLMChainでは以下のように、文字列を入力し、文字列を出力として得ることができます。
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
import dotenv
dotenv.load_dotenv()
llm = ChatOpenAI(temperature=0.9)
prompt = ChatPromptTemplate.from_messages(
[("human", "What is a good name for a company that makes {product}?")]
)
from langchain.chains import LLMChain
chain = LLMChain(llm=llm, prompt=prompt)
# Run the chain only specifying the input variable.
print(chain.run("colorful socks"))
Socktastic
ChatOpenAIを使うと類似の処理が以下のように書けます。
from langchain.chat_models import ChatOpenAI
chat = ChatOpenAI()
from langchain.schema import (
AIMessage,
HumanMessage,
SystemMessage
)
input = [HumanMessage(content="What is a good name for a company that makes colorful socks?")]
output = chat(input)
print(output)
content='VibrantSock Co.' additional_kwargs={} example=False
LLMChainは内部でLLM(今回はChatOpenAI)を呼び出しているはずですので、どのように呼ばれているのかを追うのが目的です。
方針
langchainのソースコードを読みます。バージョンはv0.0.268です。
入力
Chain
LLMChain.runからスタートしますが、継承元のChainで定義されているので、そっちを見ます。
Chainのrun関数を見ます。
def run(self, *args: str, **kwargs: str) -> str:
"""Run the chain as text in, text out or multiple variables, text out."""
if len(self.output_keys) != 1:
raise ValueError(
f"`run` not supported when there is not exactly "
f"one output key. Got {self.output_keys}."
)
if args and not kwargs:
if len(args) != 1:
raise ValueError("`run` supports only one positional argument.")
return self(args[0])[self.output_keys[0]]
if kwargs and not args:
return self(kwargs)[self.output_keys[0]]
raise ValueError(
f"`run` supported with either positional arguments or keyword arguments"
f" but not both. Got args: {args} and kwargs: {kwargs}."
)
runには位置引数かキーワード引数を入力可能です。
位置引数とキーワード引数を両方いれるとエラーになります。
位置引数を複数指定した場合もエラーになります。
引数はself.__call__に渡されています。
Chain.__call__を見ます。
def __call__(
self, inputs: Union[Dict[str, Any], Any], return_only_outputs: bool = False
) -> Dict[str, Any]:
"""Run the logic of this chain and add to output if desired.
Args:
inputs: Dictionary of inputs, or single input if chain expects
only one param.
return_only_outputs: boolean for whether to return only outputs in the
response. If True, only new keys generated by this chain will be
returned. If False, both input keys and new keys generated by this
chain will be returned. Defaults to False.
"""
inputs = self.prep_inputs(inputs)
self.callback_manager.on_chain_start(
{"name": self.__class__.__name__},
inputs,
verbose=self.verbose,
)
try:
outputs = self._call(inputs)
except (KeyboardInterrupt, Exception) as e:
self.callback_manager.on_chain_error(e, verbose=self.verbose)
raise e
self.callback_manager.on_chain_end(outputs, verbose=self.verbose)
return self.prep_outputs(inputs, outputs, return_only_outputs)
inputsはself.prep_inputsに渡されます。prepはprepareだと思います。
Chain.prep_inputsを見ます。
def prep_inputs(self, inputs: Union[Dict[str, Any], Any]) -> Dict[str, str]:
"""Validate and prep inputs."""
if not isinstance(inputs, dict):
_input_keys = set(self.input_keys)
if self.memory is not None:
# If there are multiple input keys, but some get set by memory so that
# only one is not set, we can still figure out which key it is.
_input_keys = _input_keys.difference(self.memory.memory_variables)
if len(_input_keys) != 1:
raise ValueError(
f"A single string input was passed in, but this chain expects "
f"multiple inputs ({_input_keys}). When a chain expects "
f"multiple inputs, please call it by passing in a dictionary, "
"eg `chain({'foo': 1, 'bar': 2})`"
)
inputs = {list(_input_keys)[0]: inputs}
if self.memory is not None:
external_context = self.memory.load_memory_variables(inputs)
inputs = dict(inputs, **external_context)
self._validate_inputs(inputs)
return inputs
inputsの型はdictかAnyで、前者がキーワード引数、後者が位置引数の場合に相当します。
前者の場合は、特に何も変更なく、returnされます。
後者の場合は、if not isinstance(inputs, dict):
の分岐に入ります。やっていることとしては、inputsをキーワード引数の場合と同様に扱うため、キーワード引数に変換しています。
Chain.prep_inputsで、returnされる直前に実行されるChain._validate_inputsの定義も見ておきます。
def _validate_inputs(self, inputs: Dict[str, str]) -> None:
"""Check that all inputs are present."""
missing_keys = set(self.input_keys).difference(inputs)
if missing_keys:
raise ValueError(f"Missing some input keys: {missing_keys}")
self.input_keysは初期化時に設定したプロンプトテンプレートの変数のリストです。self.input_keysとinputsの差集合を見ることで、つまりプロンプトテンプレートで定義した変数が全て入力されているかを確認しています。入力されていなければ、エラーになります。逆に、inputsが過剰な場合はスルーされます。
参考にself.input_keysの定義を以下に示しておきます。LLMChainで、定義されています。
@property
def input_keys(self) -> List[str]:
"""Will be whatever keys the prompt expects.
:meta private:
"""
return self.prompt.input_variables
self.prep_inputsの役割をまとめると以下です。
- runへのinputsが位置引数(not dict)の場合、LLMChainの初期化時に渡されたプロンプトテンプレートの変数名と組み合わせてdictにする。
- dictのinputsがプロンプトテンプレートの変数をすべて含んでいるかチェックする
Chain.prep_inputsを呼んでいた、Chain.__call__に戻ります。
inputs = self.prep_inputs(inputs)
self.callback_manager.on_chain_start(
{"name": self.__class__.__name__},
inputs,
verbose=self.verbose,
)
try:
outputs = self._call(inputs)
inputsがself._callに渡されます。
Chain._callは抽象メソッドであり、継承先のLLMChainで定義されているので、LLMChainを見ます。
LLMChain
LLMChain._callを見ます。
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
response = self.generate([inputs], run_manager=run_manager)
return self.create_outputs(response)[0]
inputsはネスト1つ深いlistになって、self.generateに渡されています。
ネストが1つ深くなるのは、self.generateがバッチ処理にも対応しているためと思われます。
LLMChain.generateを見ます。
def generate(
self,
input_list: List[Dict[str, Any]],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> LLMResult:
"""Generate LLM result from inputs."""
prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)
return self.llm.generate_prompt(
prompts,
stop,
callbacks=run_manager.get_child() if run_manager else None,
**self.llm_kwargs,
)
input_listがself.prep_promptsに渡されています。
LLMChain.prep_promptsを見ます。
def prep_prompts(
self,
input_list: List[Dict[str, Any]],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Tuple[List[PromptValue], Optional[List[str]]]:
"""Prepare prompts from inputs."""
stop = None
if len(input_list) == 0:
return [], stop
if "stop" in input_list[0]:
stop = input_list[0]["stop"]
prompts = []
for inputs in input_list:
selected_inputs = {k: inputs[k] for k in self.prompt.input_variables}
prompt = self.prompt.format_prompt(**selected_inputs)
_colored_text = get_colored_text(prompt.to_string(), "green")
_text = "Prompt after formatting:\n" + _colored_text
if run_manager:
run_manager.on_text(_text, end="\n", verbose=self.verbose)
if "stop" in inputs and inputs["stop"] != stop:
raise ValueError(
"If `stop` is present in any inputs, should be present in all."
)
prompts.append(prompt)
return prompts, stop
inputsからプロンプトテンプレートの変数に含まれるものだけをとりだし、self.prompt.format_promptに渡しています。
self.prompt.format_promptを確認します。
self.promptの型はBasePromptTemplateです。BasePrompteTemplate.format_promptは抽象メソッドであり継承先で定義しています。
今回はBasePromptTemplateを継承したBaseChatPromptTemplateを継承したChatPromptTemplateを使用しており、format_promptはBaseChatPromptTemplateで定義されています。
BaseChatPromptTemplate
def format_prompt(self, **kwargs: Any) -> PromptValue:
"""
Format prompt. Should return a PromptValue.
Args:
**kwargs: Keyword arguments to use for formatting.
Returns:
PromptValue.
"""
messages = self.format_messages(**kwargs)
return ChatPromptValue(messages=messages)
kwargsはAny型ですが、今回はプロンプトテンプレートのinput_variablesをkeyとするdictです。
kwargsがself.format_messagesに渡されています。BaseChatPromptTemplate.format_messagesは抽象メソッドなので、継承先で定義を確認します。
ChatPromptTemplate
ChatPromptTemplateを見ます。
ChatPromptTemplate.format_messagesを見ます。
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format the chat template into a list of finalized messages.
Args:
**kwargs: keyword arguments to use for filling in template variables
in all the template messages in this chat template.
Returns:
list of formatted messages
"""
kwargs = self._merge_partial_and_user_variables(**kwargs)
result = []
for message_template in self.messages:
if isinstance(message_template, BaseMessage):
result.extend([message_template])
elif isinstance(
message_template, (BaseMessagePromptTemplate, BaseChatPromptTemplate)
):
rel_params = {
k: v
for k, v in kwargs.items()
if k in message_template.input_variables
}
message = message_template.format_messages(**rel_params)
result.extend(message)
else:
raise ValueError(f"Unexpected input: {message_template}")
return result
まずkwargsがself._merge_partial_and_user_variables(**kwargs)
の結果で上書きされています。
ChatPromptTemplate._merge_partial_and_user_variablesの記述は見当たらないので、継承元を見ます。
BaseChatPromptTemplate (2回め)
BaseChatPromptTemplate._merge_partial_and_user_variablesを見ます。
def _merge_partial_and_user_variables(self, **kwargs: Any) -> Dict[str, Any]:
# Get partial params:
partial_kwargs = {
k: v if isinstance(v, str) else v()
for k, v in self.partial_variables.items()
}
return {**partial_kwargs, **kwargs}
ざっくりとはpartial_variablesとkwargsをjoinする処理です。
joinの前に、self.partial_variablesというdictのvalueがstrであればそのまま、それ以外はcallableとみなしてその引数無しでの実行結果に置き換えるという処理をしています。
self.partial_variablesの型は以下です。
partial_variables: Mapping[str, Union[str, Callable[[], str]]] = Field(
default_factory=dict
)
Mappingは辞書ライクな型を意味します。Filed(default_factory=dict)は初期値の設定で、mutableな型を初期値にする場合は、オブジェクトの共有を避けるため、default_factoryとして指定するようです。
つまり初期値には空のdictが設定されています。
参考: https://zenn.dev/enven/articles/8b80ff38461b4ff329aa
partial_variableの役割についてはよく分からなかったので、何らかのdict、初期値は空、くらいの認識で先に進みます。
ChatPromptTepmlate
ChatPromptTemplate.format_messagesに戻ります。
self._merge_partial_and_user_variablesの実行後、self.messagesに対してループ処理をしています。
self.messagesの型はList[MessageLike]です。
MessageLikeはUnion[BaseMessagePromptTemplate, BaseMessage, BaseChatPromptTemplate]のエイリアスです。つまり、プロンプトかプロンプトテンプレートです。
ChatPromptTemplate.format_messagesのループ処理の中身を見ます。
if isinstance(message_template, BaseMessage):
result.extend([message_template])
elif isinstance(
message_template, (BaseMessagePromptTemplate, BaseChatPromptTemplate)
):
rel_params = {
k: v
for k, v in kwargs.items()
if k in message_template.input_variables
}
message = message_template.format_messages(**rel_params)
result.extend(message)
else:
raise ValueError(f"Unexpected input: {message_template}")
elifの部分をみると、format_messagesが再帰的に呼ばれています。この再帰はifの部分、つまり、message_templateがBaseMessageになるまで続きます。
変数を置き換える具体的な処理はここには書かれておらず、軽く探した範囲では見つかりませんでした。
最終的にformat_messagesの戻り値としてList[BaseMessage]が得られるということだけざっくり理解して先に進みます。
BaseChatPromptTemplate
ChatPromptTemplate.format_messagesを呼び出していたBaseChatPromptTemplate.format_promptに戻ります。
self.format_messagesの戻り値はChatPromptValue(messages=message)に変換され、戻り値となります。
ChatPromptValueはmessages: List[BaseMessage]を格納するための型で、messagesをstringにしたりBaseMessageにするAPIを持っています。
LLMChain
BaseChatPromptTemplate.format_promptを呼び出していたLLMChain.prep_promptsに戻ります。
ChatPromptValue型の戻り値をself.format_promptから受け取ったあと、それをpromptsというリストに追加し、最終的にpromptsとstopを関数の戻り値としています。
stopはNoneまたはinput_list[0]["stop"]の値です。
LLMChain.prep_promptsを呼び出していたLLMChain.generateに戻ります。
prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)
return self.llm.generate_prompt(
prompts,
stop,
callbacks=run_manager.get_child() if run_manager else None,
**self.llm_kwargs,
)
self.prep_promptsの戻り値はそのままself.llm.generate_promptに渡されています。
self.llmは今回の想定では例えばChatOpenAIでしょうから、ここではChatOpenAI.generate_promptを呼び出していることになります。
先に進む前に、**self.llm_kwargsの由来だけ念の為確認しておきます。
class LLMChain(Chain):
"""略"""
llm_kwargs: dict = Field(default_factory=dict)
llm_kwargsはLLMChainのクラス変数の一つで、LLMChainの初期化時に、辞書型の値を渡せばそれをself.llmでの生成時に渡してくれる仕様です。self.llm側も初期化時にデフォルトのキーワードを引数を設定することができますが、優先度はllm_kwargsのほうが高いです。
BaseChatModel
ChatOpenAI.generate_promptは継承元のBaseChatModelで定義されています。
BaseChatModel.generate_prompt
を見ます。
def generate_prompt(
self,
prompts: List[PromptValue],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
prompt_messages = [p.to_messages() for p in prompts]
return self.generate(prompt_messages, stop=stop, callbacks=callbacks, **kwargs)
promptsの各要素をp.to_messagesでBaseMessage型に変換し、self.generateに渡しています。
pの型は今回はChatPromptValueなのでChatPromptValueをみます。
class ChatPromptValue(PromptValue):
"""Chat prompt value.
A type of a prompt value that is built from messages.
"""
messages: List[BaseMessage]
"""List of messages."""
def to_string(self) -> str:
"""Return prompt as string."""
return get_buffer_string(self.messages)
def to_messages(self) -> List[BaseMessage]:
"""Return prompt as a list of messages."""
return self.messages
to_messagesはシンプルにクラス変数のmessagesを返します。戻り値の型はList[BaseMessage]です。
BaseChatModel.generate_promptに戻ります。
prompt_messages = [p.to_messages() for p in prompts]
return self.generate(prompt_messages, stop=stop, callbacks=callbacks, **kwargs)
p.to_messagesで取得したList[BaseMessage]型のmessagesをBaseChatModel.generateに渡しています。
この先は、【langchain】ChatOpenAIの入出力処理をざっくり理解するで記載済みなので省略します。
要点は、self.generateの戻り値はLLMResultでLLMResultであり、LLMResult.generations[0][0].messageによって出力メッセージを取り出せるという点です。
出力
BaseChatModel.generate以降の入出力処理は前述の別記事で確認済みのため、self.generateの戻り値としてLLMResultを受け取ったところから処理をおいます。
BaseChatModel
BaseChatModel.generate_promptの中でBaseChatModel.generateを読んでいました。BaseChatModel.generateの戻り値はそのままBaseChatModel.generate_promptの戻り値となっています。
LLMChain
BaseChatModel.generate_promptを呼び出しているLLMChain.generateを確認します。
ここでもBaseChatModel.generate_promptの戻り値がそのままLLMChain.generateの戻り値となっています。
LLMChain.generateを呼んでいるLLMChain._callを見ます。
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
response = self.generate([inputs], run_manager=run_manager)
return self.create_outputs(response)[0]
self.generateの戻り値を、self.create_outputsで変換しています。
LLMChain.create_outputsを見ます。
def create_outputs(self, llm_result: LLMResult) -> List[Dict[str, Any]]:
"""Create outputs from response."""
result = [
# Get the text of the top generated string.
{
self.output_key: self.output_parser.parse_result(generation),
"full_generation": generation,
}
for generation in llm_result.generations
]
if self.return_final_only:
result = [{self.output_key: r[self.output_key]} for r in result]
return result
LLMResult.generations[i]をself.output_parser.parse_resultで変換しています。
self.output_parserは以下のように定義されています。
output_parser: BaseLLMOutputParser = Field(default_factory=StrOutputParser)
今回は明示的に指定していないのでデフォルトのStrOutputParserが代入されていると思われます。
StrOutputParserはBaseTransformParserという型を継承しています。
class StrOutputParser(BaseTransformOutputParser[str]):
"""OutputParser that parses LLMResult into the top likely string."""
@property
def lc_serializable(self) -> bool:
"""Whether the class LangChain serializable."""
return True
@property
def _type(self) -> str:
"""Return the output parser type for serialization."""
return "default"
def parse(self, text: str) -> str:
"""Returns the input text with no changes."""
return text
BaseTransformParserは特に関係する処理がなさそうなので省略します。
BaseTransformParserが継承しているBaseOutputParserを見ます。
def parse_result(self, result: List[Generation]) -> T:
"""Parse a list of candidate model Generations into a specific format.
The return value is parsed from only the first Generation in the result, which
is assumed to be the highest-likelihood Generation.
Args:
result: A list of Generations to be parsed. The Generations are assumed
to be different candidate outputs for a single model input.
Returns:
Structured output.
"""
return self.parse(result[0].text)
BaseOutputParser.parse_resultはList[Gneration]型を引数に取り、self.parse(result[0].text)に渡しています。
BaseMessageで文字列が格納されるプラパティはcontentですが、LLMResultをつくるときにcontentの値をtextプラパティにコピーしているので、result[0].textで生成結果の文字列を取得することができます。self.parseは継承先のStrOutputParserで定義されている通り文字列を加工なしで返す関数です。
よって、LLMChain.output_parser.parse_resultの処理は結局、LLMResultからテキストを取り出して返す、ということになります。
LLMchain.output_parser.parse_resultを呼び出しているLLMChain.create_outputsに戻ります。
def create_outputs(self, llm_result: LLMResult) -> List[Dict[str, Any]]:
"""Create outputs from response."""
result = [
# Get the text of the top generated string.
{
self.output_key: self.output_parser.parse_result(generation),
"full_generation": generation,
}
for generation in llm_result.generations
]
if self.return_final_only:
result = [{self.output_key: r[self.output_key]} for r in result]
return result
self.return_final_onlyがTrueの場合は、self.output_keyに対応する値だけが戻ります。
それ以外は、output_key以外のすべてのgenerationの情報が戻ります。
LLMChain.create_responseを呼び出しているLLMChain._callに戻ります。LLMChain._callはLLMChain.create_outputsの戻り値をそのまま戻しています。
Chain
LLMChain._callを呼び出しているChain.__call__をみます。
def __call__(
self, inputs: Union[Dict[str, Any], Any], return_only_outputs: bool = False
) -> Dict[str, Any]:
"""略"""
try:
outputs = self._call(inputs)
except (KeyboardInterrupt, Exception) as e:
self.callback_manager.on_chain_error(e, verbose=self.verbose)
raise e
self.callback_manager.on_chain_end(outputs, verbose=self.verbose)
return self.prep_outputs(inputs, outputs, return_only_outputs)
self._callの戻り値がself.prep_outputsに渡されています。
Chain.prep_outputsを見ます。
def prep_outputs(
self,
inputs: Dict[str, str],
outputs: Dict[str, str],
return_only_outputs: bool = False,
) -> Dict[str, str]:
"""Validate and prep outputs."""
self._validate_outputs(outputs)
if self.memory is not None:
self.memory.save_context(inputs, outputs)
if return_only_outputs:
return outputs
else:
return {**inputs, **outputs}
まずoutputsをself._validate_outputsに渡しています。
def _validate_outputs(self, outputs: Dict[str, str]) -> None:
if set(outputs) != set(self.output_keys):
raise ValueError(
f"Did not get output keys that were expected. "
f"Got: {set(outputs)}. Expected: {set(self.output_keys)}."
)
self._validate_outputsはoutputsのkeyがChainの初期化時に定義されたoutput_keysと同一か確認します。
次に、memoryが存在していれば、self.memory.save_context(inputs, outputs)を実行しています。memoryについては必要があれば別記事で詳細を追うとして今はスキップします。
最後にreturn_only_outputsがTrueであればoutputsだけを、そうでなければ、inputsとoutputsをjoinした辞書を返します。
結局、self.prep_outputsによってkey値の正しさが保証され、必要に応じてinputsとまとめられます。
self.prep_outputsの戻り値はそのままself.__call__の戻り値になります。
Chain.__call__を呼び出しているChain.runをみます。
Chain.runはChain.__call__の戻り値(dict)からoutput_keysの0番目の値を取り出して返しています。output_keysが複数個の場合はあまり想定されていないのかもしれません。
return self(args[0])[self.output_keys[0]]
入出力処理のまとめ
入出力の処理をmessageの型の変化とそれを扱う関数に注目してまとめます。
入力
strの入力が初期化時に定義したプロンプトテンプレートを用いてプロンプト化され、langchainの抽象化されたLLMに渡されます。
- 初期入力:str
- LLMChain.run
- Chain.run
- Chain.__call__
- Chain.prep_inputsで入力を辞書型に変換。keyは初期化時に設定したプロンプトテンプレートの変数名
- dict
- Chain._call
- LLMChain.generate
- LLMChain.prep_prompts
- BaseChatPromptTemplate.format_prompt
- ChatPromptTemplate.format_messages
- PromptTemplateがBaseMessageになるまで再帰
- BaseMessageをChatPromptValueに変換(格納)
- ChatPromptTemplate.format_messages
- BaseChatPromptTemplate.format_prompt
- ChatPromptValue
- LLMChain.BaseChatModel.generate_prompt
- ChatPromptValue.to_messagesでBaseMessageに変換
- LLMChain.BaseChatModel.generate_prompt
- BaseMessage
- LLMChain.BaseChatModel.generate
出力
langchainの抽象化されたLLMから受け取った値から生成結果に相当する文字列を取り出して、辞書に格納して返します。
- 初期出力: LLMResult (LLMChain.BaseChatModel.generateの戻り値からスタート)
- LLMChain.BaseChatModel.generate_prompt
- Chain.generate
- Chain._call
- LLMChain.create_outputs
- ループ処理でLLMResult.generations[i]を取り出す
- LLMResult.generations[i]をself.output_parser.parse_resultに渡し生成結果の文字列を取得し、辞書で返す
- dict
- Chain.__call__
- Chain.prep_outputsで出力をバリデーション、型はdictのまま
- Chain.run
- Chain.__call__
実験
理解の確認のため実際にコードを書いて実験します。
LLMに渡すパラメータの指定
まず普通に実行してみます。
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.chains import LLMChain
import dotenv
dotenv.load_dotenv()
llm = ChatOpenAI(temperature=0.9)
prompt = ChatPromptTemplate.from_messages(
("human", "{product}を作っている会社の名前は?")
)
chain = LLMChain(llm=llm, prompt=prompt)
# Run the chain only specifying the input variable.
print(chain.run("カラフルなソックス"))
カラフルなソックスを作っている会社の名前は、多くの会社が存在しますが、代表的な一つとして「Happy Socks」というブランドがあります。
LLMに渡すパラメータはChatOpenAIの宣言時またはChainの宣言時に指定できます。
ChatOpenAIの宣言時に渡す場合は以下です。
llm = ChatOpenAI(temperature=0.9, max_tokens=5)
prompt = ChatPromptTemplate.from_messages(
("human", "{product}を作っている会社の名前は?")
)
chain = LLMChain(llm=llm, prompt=prompt)
# Run the chain only specifying the input variable.
print(chain.run("カラフルなソックス"))
カラフルな
Chainの宣言時に渡す場合は以下です。
llm = ChatOpenAI(temperature=0.9)
prompt = ChatPromptTemplate.from_messages(
("human", "{product}を作っている会社の名前は?")
)
chain = LLMChain(llm=llm, prompt=prompt, llm_kwargs={"max_tokens": 5})
# Run the chain only specifying the input variable.
print(chain.run("カラフルなソックス"))
カラフルな
記述量としては、ChatOpenAIの宣言時に指定するほうが楽ですね。
両方で指定した場合はChainの初期化時の値が優先されます。
llm = ChatOpenAI(temperature=0.9, max_tokens=100)
prompt = ChatPromptTemplate.from_messages(
("human", "{product}を作っている会社の名前は?")
)
chain = LLMChain(llm=llm, prompt=prompt, llm_kwargs={"max_tokens": 5})
# Run the chain only specifying the input variable.
print(chain.run("カラフルなソックス"))
カラフルな
runのキーワード引数
runにキーワード引数を与えてみます。
llm = ChatOpenAI(temperature=0.9)
prompt = ChatPromptTemplate.from_messages(
("human", "{product}を作っている{noun}の名前は?")
)
chain = LLMChain(llm=llm, prompt=prompt)
# Run the chain only specifying the input variable.
print(chain.run(product="カラフルなソックス", noun="人物"))
カラフルなソックスを作っている人物の名前は様々です。具体的な名前は分かりませんが、ソックスデザイナー、ファッションデザイナー、アーティスト、クリエイターなどの多くの人々がカラフルなソックスを製作しています。
キーワード引数の代わりに辞書を与えることもできます。これはlangchainというよりpythonの特徴ですが。
print(chain.run({"product":"カラフルなソックス", "noun":"人物"}))
特定の人物を指定しない場合、カラフルなソックスを作る人々は、ソックスデザイナーやソックスメーカーとして知られることがあります。有名なソックスメーカーやデザイナーには、Happy SocksやSTANCE、Paul Smithなどがあります。ただし、ソックスを作る人々はさまざまな場所で活動しており、個々の名前は様々です。
プロンプトテンプレートに対して少ない引数を与えるとエラーになります。
print(chain.run(product="カラフルなソックス"))
ValueError: Missing some input keys: {'noun'}
逆に過剰な場合はエラーなく実行できます。
print(chain.run(product="カラフルなソックス", noun="人物", year="1992"))
その人物の名前は不特定です。カラフルなソックスを作っている人は多くの人々がいますので、具体的な名前を知るためにはその人物についての詳細情報が必要です。
runの位置引数
位置引数は1つのみ渡すことができます。
複数渡すとエラーになります。
llm = ChatOpenAI(temperature=0.9)
prompt = ChatPromptTemplate.from_messages(
("human", "{product}を作っている{noun}の名前は?")
)
chain = LLMChain(llm=llm, prompt=prompt)
# Run the chain only specifying the input variable.
print(chain.run("カラフルなソックス", "人物"))
ValueError: `run` supports only one positional argument.
よって、プロンプトテンプレートで複数の変数が定義されている場合は位置引数は使えません。
変数なしのプロンプトテンプレートの実行
変数なしのプロンプトテンプレートを使いたい場合は、辞書を与える必要があります
llm = ChatOpenAI(temperature=0.9, max_tokens=5)
prompt = ChatPromptTemplate.from_messages(
("human", "ソックスを作っている会社の名前は?")
)
chain = LLMChain(llm=llm, prompt=prompt, return_final_only=False)
# Run the chain only specifying the input variable.
print(chain.run({}))
ソックスを
余分なキーワード引数が無視されることを利用して、ダミーのキーワード引数を与えても良いです。
print(chain.run(dummy=""))
ソックスを
引数なしでrunを実行するとエラーになります。
print(chain.run())
ValueError: `run` supported with either positional arguments or keyword arguments, but none were provided.
辞書以外の位置引数を与えた場合もエラーになります。
print(chain.run(""))
ValueError: A single string input was passed in, but this chain expects multiple inputs (set()). When a chain expects multiple inputs, please call it by passing in a dictionary, eg `chain({'foo': 1, 'bar': 2})
output_keyの確認
chainのoutput_keyを確認してみます。
llm = ChatOpenAI(temperature=0.9)
prompt = ChatPromptTemplate.from_messages(
("human", "{product}を作っている{noun}の名前は?")
)
chain = LLMChain(llm=llm, prompt=prompt)
print(chain.output_key)
print(chain.output_keys)
text
['text']
output_keyはtext、output_keysは要素がtextだけのリストです。
宣言時にoutput_keyを指定できます(使い所は知りません)。
llm = ChatOpenAI(temperature=0.9)
prompt = ChatPromptTemplate.from_messages(
("human", "{product}を作っている{noun}の名前は?")
)
chain = LLMChain(llm=llm, prompt=prompt, output_key="sentence")
print(chain.output_key)
print(chain.output_keys)
sentence
['sentence']
return_final_only
宣言時にreturn_final_onlyを追加すると、output_keysにfull_generatoinsが追加されます。
llm = ChatOpenAI(temperature=0.9)
prompt = ChatPromptTemplate.from_messages(
("human", "{product}を作っている{noun}の名前は?")
)
chain = LLMChain(llm=llm, prompt=prompt, return_final_only=False)
print(chain.output_key)
print(chain.output_keys)
llm = ChatOpenAI(temperature=0.9)
prompt = ChatPromptTemplate.from_messages(
("human", "{product}を作っている{noun}の名前は?")
)
chain = LLMChain(llm=llm, prompt=prompt, return_final_only=False)
print(chain.output_key)
print(chain.output_keys)
text
['text', 'full_generation']
return_final_onlyがFalseの状態でrunを実行してみます。
直感的にはパースされる前の出力も合わせて取得できることが期待されます。
llm = ChatOpenAI(temperature=0.9)
prompt = ChatPromptTemplate.from_messages(
("human", "{product}を作っている{noun}の名前は?")
)
chain = LLMChain(llm=llm, prompt=prompt, return_final_only=False)
print(chain.run(product="カラフルなソックス", noun="人物"))
カラフルなソックスを作っている人物の具体的な名前はわかりませんが、ソックスデザイナーやソックスメーカーといった職業を持つ人々がカラフルなソックスを作っている可能性があります。
直感に反して、文字列だけが返されています。
runのソースコードではreturn_final_onlyの分岐が存在しないので、当然です。
return_final_onlyの変換を見るには、runではなく__call__を呼ぶ必要があります。
__call__
まず、return_final_only=Trueで__call__を呼んでみます。runがキーワード引数であったのに対し、__call__では辞書で引数を与える必要があることに注意です。ついでに結果をみやすくするため、max_toknes=5としています。
llm = ChatOpenAI(temperature=0.9, max_tokens=5)
prompt = ChatPromptTemplate.from_messages(
("human", "{product}を作っている{noun}の名前は?")
)
chain = LLMChain(llm=llm, prompt=prompt)
# Run the chain only specifying the input variable.
print(chain({"product":"カラフルなソックス", "noun":"人物"}))
{'product': 'カラフルなソックス', 'noun': '人物', 'text': 'カラフルな'}
出力がrunの場合より詳しく得られていることがわかります。
脱線しますが、この状態で、output_keyを設定してみます。
llm = ChatOpenAI(temperature=0.9, max_tokens=5)
prompt = ChatPromptTemplate.from_messages(
("human", "{product}を作っている{noun}の名前は?")
)
chain = LLMChain(llm=llm, prompt=prompt, output_key="sentence")
# Run the chain only specifying the input variable.
print(chain({"product":"カラフルなソックス", "noun":"人物"}))
{'product': 'カラフルなソックス', 'noun': '人物', 'sentence': 'カラフルな'}
生成結果に相当するkeyがtextからsentenceに変わっています。
話を本筋に戻し、return_final_only=Falseで__call__を読んでみます。
llm = ChatOpenAI(temperature=0.9, max_tokens=5)
prompt = ChatPromptTemplate.from_messages(
("human", "{product}を作っている{noun}の名前は?")
)
chain = LLMChain(llm=llm, prompt=prompt, return_final_only=False)
# Run the chain only specifying the input variable.
print(chain({"product":"カラフルなソックス", "noun":"人物"}))
{'product': 'カラフルなソックス', 'noun': '人物', 'text': '作っている人', 'full_generation': [ChatGeneration(text='作っている人', generation_info={'finish_reason': 'length'}, message=AIMessage(content='作っている人', additional_kwargs={}, example=False))]}
full_generationの情報も併せて返ってきました。
function calling
LLMChainでfunction callingをやってみます。
function callingのためのキーワード引数は、ChatOpenAIかLLMChainの初期化時に渡しておけば良いです。今回は、ChatOpenAIの宣言時に渡すことにします。
とりあえずrunを実行してみます。
from langchain.chat_models import ChatOpenAI
from langchain.schema import (
AIMessage,
HumanMessage,
SystemMessage
)
import json
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.chains import LLMChain
import dotenv
dotenv.load_dotenv()
functions = [
# AIが、質問に対してこの関数を使うかどうか、
# また使う時の引数は何にするかを判断するための情報を与える
{
"name": "Person",
"description": "人物の情報を得る",
"parameters": {
"type": "object",
"properties": {
"name": {
"type": "string",
"description": "人物の名前",
},
"age": {
"type": "string",
"description": "人物の年齢",
},
},
"required": ["name", "age"],
},
}
]
llm = ChatOpenAI(functions=functions, function_call={"name": "Person"})
prompt = ChatPromptTemplate.from_messages(
("human", "太郎は30歳です")
)
chain = LLMChain(llm=llm, prompt=prompt, return_final_only=False)
print("chain result:", chain.run({}))
WARNING! functions is not default parameter.
functions was transferred to model_kwargs.
Please confirm that functions is what you intended.
WARNING! function_call is not default parameter.
function_call was transferred to model_kwargs.
Please confirm that function_call is what you intended.
chain result:
ChatOpenAIの初期化時にfunctionsとfunction_callはデフォルトのパラメータではないと警告が出ますが、正しく処理されています。
function callingされているので、contentは空文字列となります。したがって、runの出力も空文字列となります。
function callingの出力を得るためには、runではなく__call__を使う必要があります。
chain = LLMChain(llm=llm, prompt=prompt, return_final_only=False)
print("chain result:", chain({}))
chain result: {'text': '', 'full_generation': [ChatGeneration(text='', generation_info={'finish_reason': 'stop'}, message=AIMessage(content='', additional_kwargs={'function_call': {'name': 'Person', 'arguments': '{\n "name": "太郎",\n "age": "30"\n}'}}, example=False))]}
function callの結果だけ取り出してみます。
chain = LLMChain(llm=llm, prompt=prompt, return_final_only=False)
response = chain({})
function_call_redsult = response["full_generation"][0].message.additional_kwargs.get("function_call")
if function_call_redsult:
arguments = function_call_redsult["arguments"]
print(arguments)
{
"name": "太郎",
"age": "30"
}
できました。
おわりに
抽象化が何重にもなされているlangchainへの理解を深めるため、LLMChainの入出力処理をソースコードを読んで追ってみました。
結果、位置引数とキーワード引数の使い分けや、runと__call__の使い分けなど、理解を深めることができました。
今回追ったのはChainの基本的な機能だけなので、機会があれば、以下の部分も今後、処理を追ってみたいです。
- メモリー
- 非同期処理
- 複数Chainの連携
- コールバックマネジャー