2
0

More than 1 year has passed since last update.

mlflow.spark.autologによるデータソース(ファイルパス、バージョン)のトラッキング

Posted at

mlflow.spark.autologという機能の存在は知っていたのですが、きちんと使ったことがなかったので使ってみました。

MLflowでモデルをトラッキングする際に使うSparkデータソースの情報をMLflowで記録することができます。Delta Lakeと組み合わせることで、データのバージョンも追跡できる様になります。

関連コンポーネントの説明

Apache Sparkとは?

Apache Sparkは、大規模なデータの高速リアルタイム処理を実現するオープンソースのクラスタコンピューティングフレームワークです。大量なデータを並列で処理することで、非常に高いパフォーマンスを発揮することができます。データ加工だけでなく、機械学習モデルのトレーニングやハイパーパラメーターチューニングを並列処理することが可能です。

MLflowとは?

機械学習モデルのライフサイクル管理のためのフレームワークを提供するソフトウェアです。機械学習のトラッキング、集中管理のためのモデルレジストリといった機能を提供します。Databricksでは、マネージドサービスとしてMLflowを利用できる様になっていますので、Databricksノートブック上でトレーニングした機械学習は自動でトラッキングされます。

Delta Lakeとは?

データレイクに格納されているデータに対して高速なデータ処理、強力なデータガバナンスを提供するストレージレイヤーソフトウェアです。ACIDトランザクションやデータのバージョン管理、インデックス作成機能などを提供します。機械学習の文脈ではデータのバージョン管理が重要な意味を持つことになります。

mlflow.spark.autologとは

以下はマニュアルの翻訳です。

読み込みを行うSparkデータソースのパス、(対応している場合)バージョン、フォーマットの記録の有効(無効)を設定します。このメソッドはスレッドセーフでなく、mlflow-spark JARがアタッチされたSparkSessionが存在していることを前提としています。エグゼキューターではなく、Sparkドライバーからコールされる必要があります(すなわち、Sparkで並列化されている関数からこのメソッド呼び出さないでください)。このAPIはSpark 3.0以降が必要です。

データソースの情報はメモリーにキャッシュされ、(データを読みむ際に存在している場合には)アクティブなMLflowランを含み、以降のすべてのMLflowランに記録されます。<以下略>

サンプルノートブックによる実践

以下ではダミーデータを使ってmlflow.spark.autologを試してみます。

Python
import mlflow.spark
import os
import shutil

# ダミーデータを作成して永続化します
df = spark.createDataFrame([
        (4, "spark i j k"),
        (5, "l m n"),
        (6, "spark hadoop spark"),
        (7, "apache hadoop")], ["id", "text"])
Python
# ドライバーノードに保存します
import tempfile
tempdir = tempfile.mkdtemp()
tempfile_path = os.path.join(tempdir, "my-data-path")

# CSVとして保存します
df.write.csv(tempfile_path, header=True)
print("CSV saved to:", tempfile_path)

Screen Shot 2022-11-22 at 16.15.58.png

以下を実行すると、一つ目のモデル(モデルは空ですが)が記録されます。

Python
# Sparkデータソースのオートロギングを有効化します
mlflow.spark.autolog()

# Sparkデータソースの読み込みを起動するために toPandas() を呼び出します。
# データソースの情報(パスとフォーマット)が現在アクティブなランに記録されます。
# あるいは、現在アクティブなランがない場合には次に作成されたランに記録されます。
with mlflow.start_run() as active_run:
  # SparkデータフレームとしてCSVを読み込みます
  loaded_df = spark.read.csv(tempfile_path,
                header=True, inferSchema=True)

  pandas_df = loaded_df.toPandas()

タグsparkDatasourceInfoに読み込んだデータのパスが記録されます。
Screen Shot 2022-11-22 at 16.18.48.png
Screen Shot 2022-11-22 at 16.17.55.png

バージョン番号が記録されることを確認するために、Delta Lake形式で保存します。

Python
# Delta Lakeで保存するDBFSのパス
delta_dbfs_path = "/tmp/databricks_handson/takaakiyayoidatabrickscom/dummy.delta"

dbutils.fs.rm(delta_dbfs_path, True)

# Delta Lakeで保存します
df.write.format("delta").mode("overwrite").save(delta_dbfs_path)

保存したDelta Lakeのデータを読み込むと同時に、データソースを記録します。

Python
# Sparkデータソースの読み込みを起動するために toPandas() を呼び出します。
# データソースの情報(パスとフォーマット)が現在アクティブなランに記録されます。
# あるいは、現在アクティブなランがない場合には次に作成されたランに記録されます。
with mlflow.start_run() as active_run:
  # SparkデータフレームとしてDelta Lakeを読み込みます
  loaded_df = spark.read.format("delta").load(delta_dbfs_path)

  pandas_df = loaded_df.toPandas()

バージョン番号とパスが記録されています。
Screen Shot 2022-11-22 at 16.21.33.png

Delta Lakeのデータを更新してバージョン番号をインクリメントします。

Python
# ダミーデータを更新します
df = spark.createDataFrame([
        (3, "spark test"),
        (4, "spark i j k"),
        (5, "l m n"),
        (6, "spark hadoop spark"),
        (7, "apache hadoop")], ["id", "text"])

# Delta Lakeで保存します
df.write.format("delta").mode("overwrite").save(delta_dbfs_path)

Deltaのバージョン履歴を確認します。

SQL
%sql
DESCRIBE HISTORY "/tmp/databricks_handson/takaakiyayoidatabrickscom/dummy.delta"

Screen Shot 2022-11-22 at 16.23.25.png

最新バージョンのDelta Lakeデータを読み込んで、データソースを記録します。

Python
# Sparkデータソースの読み込みを起動するために toPandas() を呼び出します。
# データソースの情報(パスとフォーマット)が現在アクティブなランに記録されます。
# あるいは、現在アクティブなランがない場合には次に作成されたランに記録されます。
with mlflow.start_run() as active_run:
  # SparkデータフレームとしてDelta Lakeを読み込みます
  loaded_df = spark.read.format("delta").load(delta_dbfs_path)

  pandas_df = loaded_df.toPandas()

バージョン1のデータソースが記録されています。
Screen Shot 2022-11-22 at 16.24.28.png

この様に、それぞれのトレーニング(MLflowラン)でどのバージョンのデータを使用したのが追跡されていることがわかります。
Screen Shot 2022-11-22 at 16.25.14.png

機械学習モデルをトレーニングする際には、どの時点のデータを使ってトレーニングしたのかという情報は、再現性確保の観点で重要です。MLflowとSpark、Deltaを組み合わせることでこの様な情報を容易に追跡できる様になります。

Databricks 無料トライアル

Databricks 無料トライアル

2
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
0