導入
LangchainにはCallbackというモジュールがあります。
Chainの動作をロギングする際など有用な仕組です。
しかし、2024/1/5現在、こちらのドキュメントはLCELを使わないコードが示されており、LCELでどのようにするか私自身よく理解できていませんでした。
調べたところ、以下に例が記載されていましたので、自分の備忘も含めて現時点でのLCELにおけるCallback利用方法をまとめます。
また、Callbackを利用することによるChainの動作詳細の確認や、ついでに以下の公式Docに記載のあるChainの構造確認についてもやってみます。
検証はDatabricks on AWS上で実施しました。が、Databricks以外でも同様に実行できると思います。
DBRは14.1ML、クラスタタイプはg4dn.xlargeで確認しています。
Step0. 必要なパッケージのインストール
ノートブックを作成し、必要なパッケージをインストールします。
推論はExllama v2を使いますので、exllamav2
と、Chainの構造確認のためにgrandalf
をインストールしています。
%pip install -U transformers accelerate "exllamav2>=0.0.11" langchain sentencepiece grandalf
dbutils.library.restartPython()
Step1. モデルのロードとChainの作成
モデルをロードし、動作確認用の単純なChainを作成します。
モデルは事前にダウンロード済みの以下モデルを利用しました。
ChatExllamaV2Model
クラスについては、こちらを参照してください。
from exllamav2_chat import ChatExllamaV2Model
model_path = "/Volumes/training/llm/model_snapshots/models--TheBloke--openchat-3.5-1210-GPTQ"
chat_model = ChatExllamaV2Model.from_model_dir(
model_path,
system_message_template="{}",
human_message_template="GPT4 User: {}<|end_of_turn|>",
ai_message_template="GPT4 Assistant: {}",
temperature=0.1,
top_p=0.9,
max_new_tokens=512,
)
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.prompts.chat import (
SystemMessagePromptTemplate,
AIMessagePromptTemplate,
HumanMessagePromptTemplate,
)
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
system_template = "You are a helpful AI assistant. Please reply answer in Japanese."
prompt = ChatPromptTemplate.from_messages(
[
SystemMessagePromptTemplate.from_template(system_template),
HumanMessagePromptTemplate.from_template("{query}"),
AIMessagePromptTemplate.from_template(" "),
]
)
chain = {"query": RunnablePassthrough()} | prompt | chat_model | StrOutputParser()
Step2. Callbacksを指定して推論
今回のポイント。
Callbackを指定して推論します。
Callbackとして、StdOutCallbackHandler
をまずは使用してみます。
LCELのChainに対してCallbackを実行する際には、invoke
やstream
メソッド等のconfig
パラメータにRunnableConfig
を与えることでCallbackを設定できます。
以下、例です。
from langchain.callbacks import StdOutCallbackHandler
from langchain_core.runnables import RunnableConfig
# RunnableConfigにcallbacksをキーとしてcallbackオブジェクトのリストを設定
config = RunnableConfig({'callbacks': [StdOutCallbackHandler()]})
# streamやinvokeを呼び出すときに、configパラメータにRunnableConfigのインスタンスを指定する
for s in chain.stream("東京の明日の天気は?", config=config):
print(s, end="", flush=True)
> Entering new RunnableParallel chain...
> Entering new RunnablePassthrough chain...
> Finished chain.
> Finished chain.
> Entering new ChatPromptTemplate chain...
> Finished chain.
> Entering new StrOutputParser chain...
東京の明日の天気について、現在の情報が手元になっていないため、具体的な回答をしかたがえます。天気予報に関する最新の情報は、地域の天気予報機関やオンラインの天気サイトで確認してください。
> Finished chain.
> Finished chain.
ちなみに、RunnableConfig
を使わなくても、以下のように直接辞書型データとして渡しても同様に動作します。
for s in chain.stream("東京の明日の天気は?", config={'callbacks': [StdOutCallbackHandler()]}):
print(s, end="", flush=True)
また、StdOutCallbackHandler
では一部の情報のみ出力されますが、より詳細な情報を確認したい場合はConsoleCallbackHandler
を使用することができます。
from langchain.callbacks.tracers import ConsoleCallbackHandler
from langchain_core.runnables import RunnableConfig
config = RunnableConfig({'callbacks': [ConsoleCallbackHandler()]})
chain.invoke("東京の明日の天気は?", config=config)
[chain/start] [1:chain:RunnableSequence] Entering Chain run with input:
{
"input": "東京の明日の天気は?"
}
[chain/start] [1:chain:RunnableSequence > 2:chain:RunnableParallel<query>] Entering Chain run with input:
{
"input": "東京の明日の天気は?"
}
[chain/start] [1:chain:RunnableSequence > 2:chain:RunnableParallel<query> > 3:chain:RunnablePassthrough] Entering Chain run with input:
-- 中略 --
[chain/end] [1:chain:RunnableSequence > 6:parser:StrOutputParser] [0ms] Exiting Parser run with output:
{
"output": "東京の明日の天気について、現在の情報が持っていないため、具体的な回答をお約束できません。しかし、あなたが知りたいと思っている天気情報を取得するために、以下の方法があります。\n\n1. 地域の天気予報サイト(例えば、JMAのウェブサイト)をチェックする\n2. スマートフォン上の天気アプリを使用して天気情報を確認する\n3. 地域のテレビやラジオで天気予報を聞く\n\nこれらの方法を試して、東京の明日の天気情報を取得してみてください。"
}
[chain/end] [1:chain:RunnableSequence] [4.17s] Exiting Chain run with output:
{
"output": "東京の明日の天気について、現在の情報が持っていないため、具体的な回答をお約束できません。しかし、あなたが知りたいと思っている天気情報を取得するために、以下の方法があります。\n\n1. 地域の天気予報サイト(例えば、JMAのウェブサイト)をチェックする\n2. スマートフォン上の天気アプリを使用して天気情報を確認する\n3. 地域のテレビやラジオで天気予報を聞く\n\nこれらの方法を試して、東京の明日の天気情報を取得してみてください。"
}
Chainの持つメソッドastream_log
を利用することで再細粒度の詳細なログ情報を得ることができますが、ConsoleCallbackHandler
を使うことで、割と丁度いい程度のログ情報を出力することができます。
Step3. Chainの構造を確認
Chainの実行ログを確認することも大事ですが、複雑なChainを構築した場合、実行前に構造を把握したいときがあります。
LCELのChainでは、構造を確認するためのメソッドget_graph
が提供されています。
chain.get_graph()
Graph(nodes={'17436a52d60f493889ec5c837d3556a1': Node(id='17436a52d60f493889ec5c837d3556a1', data=<class 'pydantic.main.RunnableParallel<query>Input'>), 'a1aef6173dc24c439c8bc8595bc1ded6': Node(id='a1aef6173dc24c439c8bc8595bc1ded6', data=RunnablePassthrough()), '015bcdbe13124216b10a4becda4190d6': Node(id='015bcdbe13124216b10a4becda4190d6', data=ChatPromptTemplate(input_variables=['query'], messages=[SystemMessagePromptTemplate(prompt=PromptTemplate(input_variables=[], template='You are a helpful AI assistant. Please reply answer in Japanese.')), HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=['query'], template='{query}')), AIMessagePromptTemplate(prompt=PromptTemplate(input_variables=[], template=' '))])), '62f85eefc0f34af980b009bea2464f67': Node(id='62f85eefc0f34af980b009bea2464f67', data=ChatExllamaV2Model(exllama_config=<exllamav2.config.ExLlamaV2Config object at 0x7f281451bdc0>, exllama_model=<exllamav2.model.ExLlamaV2 object at 0x7f27e0fdba60>, exllama_tokenizer=<exllamav2.tokenizer.ExLlamaV2Tokenizer object at 0x7f274047d060>, exllama_cache=<exllamav2.cache.ExLlamaV2Cache object at 0x7f274047c550>, human_message_template='GPT4 User: {}<|end_of_turn|>', ai_message_template='GPT4 Assistant: {}', max_new_tokens=512, temperature=0.1, top_p=0.9)), 'c63b4c02ef6f485f9c618ccf9189af51': Node(id='c63b4c02ef6f485f9c618ccf9189af51', data=StrOutputParser()), 'a95cb54f78dd4bbe917d8da821919d0b': Node(id='a95cb54f78dd4bbe917d8da821919d0b', data=<class 'pydantic.main.StrOutputParserOutput'>)}, edges=[Edge(source='17436a52d60f493889ec5c837d3556a1', target='a1aef6173dc24c439c8bc8595bc1ded6'), Edge(source='a1aef6173dc24c439c8bc8595bc1ded6', target='015bcdbe13124216b10a4becda4190d6'), Edge(source='015bcdbe13124216b10a4becda4190d6', target='62f85eefc0f34af980b009bea2464f67'), Edge(source='c63b4c02ef6f485f9c618ccf9189af51', target='a95cb54f78dd4bbe917d8da821919d0b'), Edge(source='62f85eefc0f34af980b009bea2464f67', target='c63b4c02ef6f485f9c618ccf9189af51')])
得られるGraphオブジェクトから、もう少し人間に優しい可視化をすることもできます。
chain.get_graph().print_ascii()
+----------------------+
| Parallel<query>Input |
+----------------------+
*
*
*
+-------------+
| Passthrough |
+-------------+
*
*
*
+--------------------+
| ChatPromptTemplate |
+--------------------+
*
*
*
+--------------------+
| ChatExllamaV2Model |
+--------------------+
*
*
*
+-----------------+
| StrOutputParser |
+-----------------+
*
*
*
+-----------------------+
| StrOutputParserOutput |
+-----------------------+
今回はシンプルなChainのため、直線的な動作となりますが、複数のパラメータ利用や複数Chainの連結など、複雑なChainの構造を把握するにはこれらの機能を利用するのがよいと思います。
まとめ
LangchainのLCELにおいて、Callbackを指定する方法を実践してみました。
Langchainにおいて、すっかりLCELでChainを書くのが当たり前になってきました。
まだ動作感がよくわからないところも残ってはいますが、より習熟してうまく使いこなしていきたいと思います。