導入
MLflow 2.14.0が本日リリースされていました。
メジャー機能としてDAIS2024のキーノートでも触れられていたMLflow Tracingがあります。
また、2.14.0リリース前から、Databricksの公式ドキュメントにもTracingに関する文書が掲載されています。
MLflowへのTracing機能はWandBやLangSmithを触っていたときから期待していたものの一つというのもあり、ひとまず軽く試してみました。
試験環境はDatabricks on AWS、DBRは15.2MLです。
MLflow Tracingとは
概念理解としては、以下のMLflow公式ドキュメントを参照するのがいいかと思います。
機械翻訳で一部抜粋すると、
機械学習(ML)の文脈におけるトレースとは、MLモデルの実行中のデータフローと処理ステップの詳細な追跡と記録を指します。これは、データ入力から予測出力まで、モデルの動作の各段階に対する透明性と洞察を提供します。この詳細な追跡は、MLモデルのデバッグ、最適化、性能理解に極めて重要です。
というわけで、雑に言えば初期のLangSmithのように、LLMにおける各種処理の実行結果をロギングし、どのような入力内容でどのような出力を得たのかを処理ごとに確認できます。
実際にどのようなものか、見た方が早いと思うので実行してみます。
Step1. パッケージインストール
MLflow含めて、パッケージをインストールします。
mlflow
のバージョン指定を行っていませんが、2024/6/17時点で2.14.0がインストールされるはずです。
%pip install -U langchain==0.2.1 langchain-community==0.2.1 mlflow-skinny[databricks]
dbutils.library.restartPython()
Step2. 簡単に試す
純粋にトレーシングを行うことでどのような内容を追えるかを確認してみます。
MLflow TracingはFluent APIとMLflow Client APIの2種類のインターフェースが提供されています。
それぞれの違いは以下のドキュメントから比較表を見るのがよいと思いますが、単純利用においてはFluent APIが推奨されています。
Fluent APIはtrace
関数デコレータが提供されており、これを関数につけることで簡単にトレース対象にできます。
import mlflow
@mlflow.trace(name="func", span_type="TYPE", attributes={"key": "value"})
def my_function(x, y):
return x + y
my_function(1, 2)
関数を実行すると、Notebook上でトレーシングのUIが確認できます。
実行時間や入力内容、出力が表示されます。
また、MLflowのエクスペリメント内でも結果が記録されます。
エクスペリメントUIを確認すると、「トレース」タブが!
以下のようにトレース対象の処理が記録されています。
なお、リンクになっているリクエストIDをクリックすると、ノートブック側で表示されたUIと同じものが表示されます。
trace
デコレータを付与した関数はネストすることができるため、デコレータを付与した関数内で、他の付与した関数を呼び出すと、階層構造となって記録されます。
@mlflow.trace(name="func_root", span_type="TYPE", attributes={"key": "value"})
def my_function_root(x, y):
my_function(x, y)
my_function(x + 1, y + 1)
my_function(x + 2, y + 2)
my_function(x + 3, y + 3)
my_function_root(1, 2)
Step3. LangChainと組み合わせて試す
単純な関数実行をトレースしてもあまりメリットはありません。
本来の目的であるGenAI(LLM)で使ってみましょう。
今回は、以下のドキュメントにあるLangChainでの自動トレーシングを試してみます。
import os
import mlflow
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.chat_models.databricks import ChatDatabricks
from langchain_core.output_parsers import StrOutputParser
# langchainでの自動ロギングを有効化。特段のパラメータが指定が無いと、トレースだけ有効化される
mlflow.langchain.autolog()
# 以前作成したDatabricks Model Serving Endpointを利用してLLMを利用
# 任意のModel Serving Endpointを利用してください
llm = ChatDatabricks(
target_uri="databricks",
endpoint="mistral-7b-instruct-v03-endpoint",
temperature=0.1,
)
prompt_template = (
"Imagine that you are {person}, and you are embodying their manner of answering questions posed to them. "
"While answering, attempt to mirror their conversational style, their wit, and the habits of their speech "
"and prose. You will emulate them as best that you can, attempting to distill their quirks, personality, "
"and habits of engagement to the best of your ability. Feel free to fully embrace their personality, whether "
"aspects of it are not guaranteed to be productive or entirely constructive or inoffensive."
"The question you are asked, to which you will reply as that person, is: {question}"
)
prompt = ChatPromptTemplate.from_template(prompt_template)
# Chain作成
chain = prompt | llm | StrOutputParser()
# テストその1
chain.invoke(
{
"person": "Richard Feynman",
"question": "Why should we colonize Mars instead of Venus?",
}
)
# テストその2
chain.invoke(
{
"person": "Linus Torvalds",
"question": "Can I just set everyone's access to sudo to make things easier?",
}
)
実行すると、以下のようなUIが表示されます。
(また、MLflowエクスペリメントUI上でも記録されます)
LangSmithのようにRunnableSequenceの単位に処理が記録されました。
UI自体は非常にシンプルですが、最低限の確認したい情報が見れるようになっていると思います。
また、トレースのオーバーヘッドもそんなに感じませんでした。
Step4. トレーシングを無効化する
本番運用においては設定したトレーシングを無効化したいときもあります。
その場合は、以下に記載があるようにmlflow.tracing.disable
APIをコールすればよいようです。
mlflow.tracing.disable()
my_function(1, 2)
トレーシングのUIが表示されず、記録されなくなりました。
まとめ
MLflow Tracingを軽く試してみました。
初期リリースに関わらず、今回試せてないSPANなど多くの機能が提供されています。
LLMを使ったCompound AI systemを構築する上でトレーシング機能は非常に大事だと思いっています。これが外部サービスを使う必要なくMlflowで完結できるのは非常に便利なので、うまく活用していきたいと思っています。