PySpark で when
をチェーンするコードを書いていたときに
「これって SQL と同じように先に書いた when
が優先される?」
「メソッドチェーンだから後ろに書いた when
で上書きされる?」
と不安になったので、実際に検証コードを書いて調べた。
ダミーデータ
df = spark.createDataFrame([(1,),(2,),(3,)], schema=('val',))
display(df)
val |
---|
1 |
2 |
3 |
Spark SQL の場合
# Spark SQL から触れるように一時テーブルとして登録
df.registerTempTable('tmp')
SELECT
val,
CASE
WHEN val <= 1 THEN 'label_1'
WHEN val <= 2 THEN 'label_2'
ELSE 'label_3'
END AS label
FROM tmp
val | label |
---|---|
1 | label_1 |
2 | label_2 |
3 | label_3 |
SQL の場合は当然、先に書いた WHEN
の条件が優先される。
PySpark の場合
from pyspark.sql import functions as F
df_label = df.withColumn('label',
F.when(F.col('val') <= 1, 'label_1')
.when(F.col('val') <= 2, 'label_2')
.otherwise('label_3')
)
display(df_label)
val | label |
---|---|
1 | label_1 |
2 | label_2 |
3 | label_3 |
PySpark で when
をチェーンした場合でも、Spark SQL と同様に先に書いた when
の条件が優先されるらしい。