さっそくやってみた。
導入
前回の下記記事で、Langchain + MLFlowによる推論時のアクションログをトラッキングする方法を試してみました。
とはいえ、これはMLFlowネイティブの機能というわけではなく、少し無理やり感があることは否めません。
(MLFlowユーザとしてはオフィシャルな機能対応を熱望しています)
LLMを利用する際のアクションログトラッキングといえば、Weights and BiasesやLangSmithが有名です。
というわけで、今回はWeights and Biases(WandB)を使って、Langchainのログトラッキングを試してみます。
検証にあたってnpaka大先生の記事を大いに参考にさせてもらいました。
(というか、これを読んだのでやってみようと思っていた)
また、Langchainの該当公式ドキュメントは以下となります。
こちらも十分わかりやすいです。
検証はDatabricks on AWS上で実施しました。
前回同様、DBRは14.1ML、クラスタタイプはg4dn.xlargeを使用しました。
Step0. WandBを利用するための準備
以下の記事を参考にして、WandBのアカウントを取得してください。詳細は割愛。
その上で、APIキーをDatabricksのシークレットに登録しておきます。
今回はDatabricks CLIを使ってスコープをwandb、キーをapi_keyという名前で登録しておきました。
以下のようなコマンドで登録できます。(XXXXXXXXXXXXXXXXXがAPIキー部分)
databricks secrets create-scope wandb
databricks secrets put-secret wandb api_key --string-value XXXXXXXXXXXXXXXXXXXXXXXXXXX
Step1. パッケージインストール&環境変数
Databricks上でノートブックを作成し、必要なパッケージをインストール。
モデルの推論にはExllama V2を利用します。
wandb
パッケージをインストールしているのがいつもとの変化点です。
%pip install -U transformers accelerate "exllamav2>=0.0.11" langchain sentencepiece wandb
dbutils.library.restartPython()
Step2. ロギングするためのChainを準備
モデルをロードし、ロギング動作確認用の単純なChainを作成します。
一部のプロンプトを除いて、前回と同じです。
モデルは事前にダウンロード済みの以下モデルを利用しました。
ChatExllamaV2Model
クラスについては、こちらを参照してください。
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.prompts.chat import (
SystemMessagePromptTemplate,
AIMessagePromptTemplate,
HumanMessagePromptTemplate,
)
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from exllamav2_chat import ChatExllamaV2Model
model_path = "/Volumes/training/llm/model_snapshots/models--TheBloke--openchat-3.5-1210-GPTQ"
chat_model = ChatExllamaV2Model.from_model_dir(
model_path,
system_message_template="{}",
human_message_template="GPT4 User: {}<|end_of_turn|>",
ai_message_template="GPT4 Assistant: {}",
temperature=0.1,
top_p=0.3,
max_new_tokens=512,
)
system_template = "You are a helpful AI assistant. Please reply answer in Japanese."
template = """Answer the following question.
Question: {question}
"""
prompt = ChatPromptTemplate.from_messages(
[
SystemMessagePromptTemplate.from_template(system_template),
HumanMessagePromptTemplate.from_template(template),
AIMessagePromptTemplate.from_template(" "),
]
)
chain = ({"question": RunnablePassthrough()}
| prompt
| chat_model
| StrOutputParser())
Step2. 推論とトレース実行
まず、環境変数でWandBのトレースを有効化します。
import os
# WandBのトレースを有効化
os.environ["LANGCHAIN_WANDB_TRACING"] = "true"
# WANDB_PROJECTを設定することで、WandB上のプロジェクト名を指定できる
os.environ["WANDB_PROJECT"] = "langchain-tracing"
# 後ほど利用
from langchain.callbacks import wandb_tracing_enabled
事前に取得したAPIキーを使ってWandBにログイン。
WANDB_API_KEY
環境変数にキーを設定するだけでも大丈夫かと思われます。(未検証)
import wandb
api_key = dbutils.secrets.get("wandb", "api_key")
wandb.login(key=api_key)
では、試しに通常通り推論を実行してみましょう。
# 推論実行
async for s in chain.astream("Databricksとは何?"):
print(s, end="", flush=True)
wandb: Streaming LangChain activity to W&B at https://wandb.ai/xxxxxxxxx/langchain-tracing/runs/adaeb4iw
wandb: `WandbTracer` is currently in beta.
wandb: Please report any issues to https://github.com/wandb/wandb/issues with the tag `langchain`.
データブリックス(Databricks)は、アメリカのデータエンジニアリング・コーポレーションであり、大規模なデータ処理や分析を行うためのクラウド上のプラットフォームを提供しています。その主要な技術は、Apache Sparkというオープンソースのデータ処理エンジンに基づいています。データブリックスは、機械学習、ビッグデータ分析、AI開発など様々な分野で役立てられることが期待されています。
wandb: WARNING WARNING: Failed to serialize model: Object of type 'ExLlamaV2Config' is not JSON serializable
推論結果と合わせて、WandBへ連携している旨のメッセージやWARNING等が表示されました。
WandBのホームへ行くと、langchain-tracing
という名前でプロジェクトができており、以下のようにログが記録されているのを確認できます。
非常に簡単にトレースできました。
また、推論速度にもほとんど影響が無さそうです。
Step3. トレース範囲の限定
これでLangchainを使って推論するだけでWandBにログが簡単に記録できるようになりましたが、状況によってはログの記録範囲を限定したいときがあります。
単純にはLANGCHAIN_WANDB_TRACING
環境変数を以下のように切り替えればいいのですが、明示的にトレース範囲を指定したい場合などは面倒です。
# LANGCHAIN_WANDB_TRACING環境変数が登録されていたら、削除する
if "LANGCHAIN_WANDB_TRACING" in os.environ:
del os.environ["LANGCHAIN_WANDB_TRACING"]
# WandBにロギングされない
async for s in chain.astream("東京の明日の天気は?"):
print(s, end="", flush=True)
この場合、以下のようにwandb_tracing_enabled
を利用することで、範囲を明示してログを記録することができます。
# LANGCHAIN_WANDB_TRACING環境変数が未設定の場合、wandb_tracing_enabled context managerの範囲内でのみトレースされる。
with wandb_tracing_enabled():
async for s in chain.astream("東京の明日の天気は?"):
print(s, end="", flush=True)
# これは記録されない。
async for s in chain.astream("大阪の明日の天気は?"):
print(s, end="", flush=True)
非常に便利ですね!
まとめ
Langchain + WandBでのログトラッキング(トレース)を試してみました。
LangchainとWandBの連携は、他にCallbackを使うやり方もあるようなのですが、こちらは既に非推奨になっています。
今回、初めてWandBを利用してみたのですが、非常に便利ですね。
運用時の監視も含めて、非常に有用なサービスだと感じました。
いつになるかわかりませんが、今後はLangSmithも試してみたいと思います。