LangSmithやWandBも試していきたい。
導入
Langchainのドキュメントを見ていて、MLflowにログトラッキングするらしい機能があることに気づきました。
丁度Weights and Biasesのような機能がMLflowにもあったらなーと思っていた&ドキュメント見ても挙動がよくわからなかったので試してみます。
MLFlowって何?という方は以下の記事を参照ください。
検証はDatabricks on AWS上で実施しました。
自前でMLFlowのserver類をホスティングしている場合は以下のコードがそのまま動かず、追加の設定が必要かもしれません。
DBRは14.1ML、クラスタタイプはg4dn.xlargeを使用しました。
Step0. パッケージインストール
必要なパッケージをインストール。
モデルの推論にはExllama V2を利用します。
また、MLflowのトラッキングに必要なパッケージとして、mlflow
、textstat
、spacy
の最新版をインストールします。
加えて、spacy
で利用するモデル:en_core_web_smをwheelを使ってインストールしています。
(wheelを使わなくても通常のダウンロード方法で実施できると思います。私の環境都合上、今回はwheelファイルを使ってインストールしました)
%pip install -U transformers accelerate "exllamav2>=0.0.11" langchain sentencepiece
%pip install -U mlflow textstat spacy
%pip install https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl
dbutils.library.restartPython()
Step1. ロギングするためのChainを準備
モデルをロードし、ロギング動作確認用の単純なChainを作成します。
今回はなるべく英語で回答を得るようなプロンプトにしました。理由は後述。
モデルは事前にダウンロード済みの以下モデルを利用しました。
ChatExllamaV2Model
クラスについては、こちらを参照してください。
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.prompts.chat import (
SystemMessagePromptTemplate,
AIMessagePromptTemplate,
HumanMessagePromptTemplate,
)
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from exllamav2_chat import ChatExllamaV2Model
model_path = "/Volumes/training/llm/model_snapshots/models--TheBloke--openchat-3.5-1210-GPTQ"
chat_model = ChatExllamaV2Model.from_model_dir(
model_path,
system_message_template="{}",
human_message_template="GPT4 User: {}<|end_of_turn|>",
ai_message_template="GPT4 Assistant: {}",
temperature=0.1,
top_p=0.3,
max_new_tokens=512,
)
system_template = "You are a helpful AI assistant."
template = """Answer the following question in English.
Question: {question}
"""
prompt = ChatPromptTemplate.from_messages(
[
SystemMessagePromptTemplate.from_template(system_template),
HumanMessagePromptTemplate.from_template(template),
AIMessagePromptTemplate.from_template(" "),
]
)
chain = ({"question": RunnablePassthrough()}
| prompt
| chat_model
| StrOutputParser())
Step2. 推論とMLFlowによるトラッキング実行
今回のポイント。
MLFlowによるログトラッキングには、MlflowCallbackHandler
というcallback用のクラスを利用するようなので、インスタンスを作成します。
LCELでコールバックを指定するために、RunnableConfig
でコールバックインスタンスをラップします。
from langchain.callbacks import MlflowCallbackHandler
from langchain_core.runnables import RunnableConfig
mlflow_callback = MlflowCallbackHandler()
config = RunnableConfig({'callbacks': [mlflow_callback]})
作成したconfigを用いて、推論を実施。
# 推論実行
async for s in chain.astream("Databricksとは何?", config=config):
print(s, end="", flush=True)
エラーメッセージが出ますが、推論結果が出力されます。(以下出力結果の Answer: ...の部分)
また、ロギングしている影響で、非常にゆっくり個々のトークンが出力されます。
Error in MlflowCallbackHandler.on_chain_end callback: AttributeError("'str' object has no attribute 'items'")
Answer: Databricks is a cloud-based unified analytics platform that provides a fast, collaborative environment for data engineers, data scientists, and business analysts to create, deploy, and manage machine learning models and applications. It is built on top of Apache Spark and offers a range of tools and services for data processing, machine learning, and artificial intelligence. Databricks simplifies the process of working with large datasets and enables users to quickly analyze and derive insights from their data.
Error in MlflowCallbackHandler.on_chain_end callback: AttributeError("'str' object has no attribute 'items'")
Error in MlflowCallbackHandler.on_chain_end callback: AttributeError("'str' object has no attribute 'items'")
最後にflush_tracker
を実行。
mlflow_callback.flush_tracker(chain)
ただ、これはエラーになりました。
ローカルLLMを使ったChat Modelの実装問題かな。。。LLMにOpenAIサービス等を使えば正しく動くと思います。
Step3. トラッキング結果の確認
flush_tracker
メソッドの実行はエラーになりましたが、(一部の)トラッキングログ自体はMLFlowのエクスペリメント上に記録されました。
中を見ると、メトリクスにChainのイベントコール数などが記録されています。
(いろいろありすぎて個々の意味合いを理解できてはいませんが、推論結果を分析したメトリクス等も含まれていそう)
また、ログはMLflowアーティファクトに各ログファイルとして保管されていました。
以下のようなjson形式で保管されます。
また、生成した推論結果について、単語間の関連や各種Entityの可視化などを行ったファイルが生成されていました。
このあたり、日本語だと正しく分析されなかったので、今回は意図的に推論結果を英語にして出力しました。
他には、ログが表形式のHTMLで保管されているtable_action_records.html
というファイルもできていました。
あまり見やすくはないですが、ログ履歴確認に利用できそうです。
まとめ
Langchain + MLFlowでのログトラッキングを試してみました。
推論あたりのロギングには結構時間がかかる(今回の推論出力で2分ぐらい)ので、正直実用的とは言い難いのですが、面白い仕組だなと思いました。
MLFlow自体がこのあたりをカバーし、非同期にロギングする機構とかできると面白いですね。
実のところまだ使ったことがないのですが、LangSmithやWandBが同種の機能を提供しているので、こちらも試していければと思っています。