1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

sklearnモデルを使ってspark DataFrameに対して予測を実行する

Last updated at Posted at 2023-01-11

databricksでsklearnの学習済みモデルをspark DataFrameに適用する方法が、意外と検索しても出てこなかったので備忘録。サンプルコードはmlflowからsklearnの学習済みモデルを読み込む想定なので、再利用される場合は適宜pickle.load()なりに変更してください。

pandas_udf

Spark 2.3以降であればpandas_udfが使えます。行ごとに展開するudfよりもpandas_udfの方が高速らしいです。
Spark 3.0以降ではpandas_udfで関数を定義する際、型ヒントの記述が必須。

from pyspark.sql.functions import pandas_udf
import pandas as pd
import mlflow

@pandas_udf('double')
def predict(*cols: pd.Series) -> pd.Series:
    model_path = f'runs://{run.info.run_id}/model'
    model = mlflow.sklearn.load_model(model_path)
    X = pd.concat(cols, axis=1)
    return pd.Series(model.predict_proba(X)[:, 1])

prediction_df = spark_df.withColumn('pred', predict(spark_df.columns))

@pandas_udf('double')でリターン値のデータ型を指定しています。モデルによって予測値のデータ型に合わせる必要があります。

pandas_udf with iterator

pandas_udfは指定されたバッチサイズごとにループ処理をしています。この方法だとループする度にモデルの読み込みが発生するため、ファイルサイズが大きいモデルの場合はオーバーヘッドが高くなります。mlflow.sklearn.load_model()を関数の外ですれば良いだけなんですが、それはそれでイケテナイです。
pandas_udfをイテレータで定義すると、モデルの読み込みが1度で済みます。ただしSpark 3.0以降でのみ対応。

from typing import Iterator, Tuple
@pandas_udf('double')
def predict(iterator: Iterator[pd.DataFrame]) ->  Iterator[pd.DataFrame]:
    model_path = f'runs://{run.info.run_id}/model'
    model = mlflow.sklearn.load_model(model_path)
    for features in iterator:
        pdf = pd.concat([features], axis = 1)
        yield pd.Series(model.predict(pdf))
prediction_df = spark_df.withColumn('pred', predict(spark_df.columns))

なお、pandas_udfのバッチサイズはspark.conf.get(‘spark.sql.execution.arrow.maxRecordsPerBatch’)で確認可能です。。デフォルトは10,000。

mapInPandas

Spark 3.0以降に登場した、pandas DataFrameを入力値にした関数をそのままspark DataFrameに適用するメソッド。ただしリターン値となるspark DataFrameのスキーマを引数として渡す必要があるため、若干面倒。spark(静的言語のJAVA)にそのまま渡すイメージっぽい(あくまで個人の主観です)。

from typing import Iterator, Tuple
from pyspark.sql.functions import lit
from pyspark.sql.types import DoubleType

@pandas_udf('double')
def predict(iterator: Iterator[pd.DataFrame]) ->  Iterator[pd.DataFrame]:
    model_path = f'runs://{run.info.run_id}/model'
    model = mlflow.sklearn.load_model(model_path)
    for features in iterator:
        yield pd.concat([features, pd.Series(model.predict(features), name = 'prediction'))], axis = 1)
# 予測値のカラムをデータ値をNoneとして定義し、そのスキーマ情報を取得
schema = spark_df.withColumn('prediction', lit(None).cast(DoubleType())).schema
spark_df = spark_df.mapInPandas(predict, schema)

参考

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?