2024/4/12に翔泳社よりApache Spark徹底入門を出版します!
書籍のサンプルノートブックをウォークスルーしていきます。Python/Chapter07/7-4 Vectorized UDFs
となります。
翻訳ノートブックのリポジトリはこちら。
ノートブックはこちら
UDF、ベクトライズドUDF、ビルトインのメソッドのパフォーマンスを比較してみましょう。
ダミーデータを生成することからスタートしましょう。
from pyspark.sql.types import *
from pyspark.sql.functions import col, count, rand, collect_list, explode, struct, count, pandas_udf
df = (spark
.range(0, 10 * 1000 * 1000)
.withColumn("id", (col("id") / 1000).cast("integer"))
.withColumn("v", rand()))
df.cache()
df.count()
10000000
display(df)
列の値を1増加
データフレームのそれぞれの値に1を加算するシンプルな例からスタートします。
PySpark UDF
@udf("double")
def plus_one(v):
return v + 1
%timeit -n1 -r1 df.withColumn("v", plus_one(df.v)).agg(count(col("v"))).show()
+--------+
|count(v)|
+--------+
|10000000|
+--------+
2.71 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
別の構文 (SQL名前空間で利用可能)
from pyspark.sql.types import DoubleType
def plus_one(v):
return v + 1
spark.udf.register("plus_one_udf", plus_one, DoubleType())
%timeit -n1 -r1 df.selectExpr("id", "plus_one_udf(v) as v").agg(count(col("v"))).show()
+--------+
|count(v)|
+--------+
|10000000|
+--------+
2.42 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
Scala UDF
うわーっ!それぞれの値に1を加算するので時間がかかっています。Scala UDFでどのくらいの時間になるのかを見てみましょう。
df.createOrReplaceTempView("df")
%scala
import org.apache.spark.sql.functions._
val df = spark.table("df")
def plusOne: (Double => Double) = { v => v+1 }
val plus_one = udf(plusOne)
import org.apache.spark.sql.functions._
df: org.apache.spark.sql.DataFrame = [id: int, v: double]
plusOne: Double => Double
plus_one: org.apache.spark.sql.expressions.UserDefinedFunction = SparkUserDefinedFunction($Lambda$9933/400593915@6bc5f2c6,DoubleType,List(Some(class[value[0]: double])),Some(class[value[0]: double]),None,false,true)
%scala
df.withColumn("v", plus_one($"v"))
.agg(count(col("v")))
.show()
+--------+
|count(v)|
+--------+
|10000000|
+--------+
ワオ!Scala UDFの方がはるかに高速です。しかし、Spark 2.3時点では、Pythonでの処理を高速化する助けとなるベクトライズドUDFが利用できます。
ベクトライズドUDFは処理を高速化するためにApache Arrowを活用します。どれだけ処理時間の改善になるのかを見てみましょう。
Apache Arrowは、JVMとPythonプロセス間のデータ転送を効率的に行うためにSparkで利用されるインメモリの列指向データフォーマットです。詳細はこちらをご覧ください。
Apache Arrowが有効化されている場合とされていない場合とで、SparkデータフレームからPandasへの変換にどのくらい時間がかかるのかを見てみましょう。
spark.conf.set("spark.sql.execution.arrow.enabled", "true")
%timeit -n1 -r1 df.toPandas()
1.37 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
spark.conf.set("spark.sql.execution.arrow.enabled", "false")
%timeit -n1 -r1 df.toPandas()
23.4 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
ベクトライズドUDF
@pandas_udf("double")
def vectorized_plus_one(v):
return v + 1
%timeit -n1 -r1 df.withColumn("v", vectorized_plus_one(df.v)).agg(count(col("v"))).show()
+--------+
|count(v)|
+--------+
|10000000|
+--------+
2.16 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
いい感じです!Scala UDFほどではありませんが、少なくとも通常のPython UDFよりは優れています!
Pandas UDFでは別の構文がいくつか存在します。
from pyspark.sql.functions import pandas_udf
def vectorized_plus_one(v):
return v + 1
vectorized_plus_one_udf = pandas_udf(vectorized_plus_one, "double")
%timeit -n1 -r1 df.withColumn("v", vectorized_plus_one_udf(df.v)).agg(count(col("v"))).show()
+--------+
|count(v)|
+--------+
|10000000|
+--------+
1.45 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
ビルトインのメソッド
ビルトインのメソッドとUDFのパフォーマンスを比較してみましょう。
from pyspark.sql.functions import lit
%timeit -n1 -r1 df.withColumn("v", df.v + lit(1)).agg(count(col("v"))).show()
+--------+
|count(v)|
+--------+
|10000000|
+--------+
406 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
ビルトインメソッドが使える場合には、ビルトインメソッドを使う方がはるかに高速です。
subtract meanの計算
ここまでは、スカラーの戻り値を取り扱ってきました。ここでは、グルーピングされたUDFを活用します。
PySpark UDF
from pyspark.sql import Row
import pandas as pd
@udf(ArrayType(df.schema))
def subtract_mean(rows):
vs = pd.Series([r.v for r in rows])
vs = vs - vs.mean()
return [Row(id=rows[i]["id"], v=float(vs[i])) for i in range(len(rows))]
%timeit -n1 -r1 (df.groupby("id").agg(collect_list(struct(df["id"], df["v"])).alias("rows")).withColumn("new_rows", subtract_mean(col("rows"))).withColumn("new_row", explode(col("new_rows"))).withColumn("id", col("new_row.id")).withColumn("v", col("new_row.v")).agg(count(col("v"))).show())
+--------+
|count(v)|
+--------+
|10000000|
+--------+
20.3 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
Vectorized UDF
def vectorized_subtract_mean(pdf: pd.Series) -> pd.Series:
return pdf.assign(v=pdf.v - pdf.v.mean())
%timeit -n1 -r1 df.groupby("id").applyInPandas(vectorized_subtract_mean, df.schema).agg(count(col("v"))).show()
+--------+
|count(v)|
+--------+
|10000000|
+--------+
5.61 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)