databricksでsklearnの学習済みモデルをspark DataFrameに適用する方法が、意外と検索しても出てこなかったので備忘録。サンプルコードはmlflowからsklearnの学習済みモデルを読み込む想定なので、再利用される場合は適宜pickle.load()なりに変更してください。
pandas_udf
Spark 2.3以降であればpandas_udfが使えます。行ごとに展開するudfよりもpandas_udfの方が高速らしいです。
Spark 3.0以降ではpandas_udfで関数を定義する際、型ヒントの記述が必須。
from pyspark.sql.functions import pandas_udf
import pandas as pd
import mlflow
@pandas_udf('double')
def predict(*cols: pd.Series) -> pd.Series:
model_path = f'runs://{run.info.run_id}/model'
model = mlflow.sklearn.load_model(model_path)
X = pd.concat(cols, axis=1)
return pd.Series(model.predict_proba(X)[:, 1])
prediction_df = spark_df.withColumn('pred', predict(spark_df.columns))
@pandas_udf('double')
でリターン値のデータ型を指定しています。モデルによって予測値のデータ型に合わせる必要があります。
pandas_udf with iterator
pandas_udfは指定されたバッチサイズごとにループ処理をしています。この方法だとループする度にモデルの読み込みが発生するため、ファイルサイズが大きいモデルの場合はオーバーヘッドが高くなります。mlflow.sklearn.load_model()
を関数の外ですれば良いだけなんですが、それはそれでイケテナイです。
pandas_udfをイテレータで定義すると、モデルの読み込みが1度で済みます。ただしSpark 3.0以降でのみ対応。
from typing import Iterator, Tuple
@pandas_udf('double')
def predict(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
model_path = f'runs://{run.info.run_id}/model'
model = mlflow.sklearn.load_model(model_path)
for features in iterator:
pdf = pd.concat([features], axis = 1)
yield pd.Series(model.predict(pdf))
prediction_df = spark_df.withColumn('pred', predict(spark_df.columns))
なお、pandas_udfのバッチサイズはspark.conf.get(‘spark.sql.execution.arrow.maxRecordsPerBatch’)
で確認可能です。。デフォルトは10,000。
mapInPandas
Spark 3.0以降に登場した、pandas DataFrameを入力値にした関数をそのままspark DataFrameに適用するメソッド。ただしリターン値となるspark DataFrameのスキーマを引数として渡す必要があるため、若干面倒。spark(静的言語のJAVA)にそのまま渡すイメージっぽい(あくまで個人の主観です)。
from typing import Iterator, Tuple
from pyspark.sql.functions import lit
from pyspark.sql.types import DoubleType
@pandas_udf('double')
def predict(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
model_path = f'runs://{run.info.run_id}/model'
model = mlflow.sklearn.load_model(model_path)
for features in iterator:
yield pd.concat([features, pd.Series(model.predict(features), name = 'prediction'))], axis = 1)
# 予測値のカラムをデータ値をNoneとして定義し、そのスキーマ情報を取得
schema = spark_df.withColumn('prediction', lit(None).cast(DoubleType())).schema
spark_df = spark_df.mapInPandas(predict, schema)