PySpark で when をチェーンするコードを書いていたときに
「これって SQL と同じように先に書いた when が優先される?」
「メソッドチェーンだから後ろに書いた when で上書きされる?」
と不安になったので、実際に検証コードを書いて調べた。
ダミーデータ
df = spark.createDataFrame([(1,),(2,),(3,)], schema=('val',))
display(df)
| val |
|---|
| 1 |
| 2 |
| 3 |
Spark SQL の場合
# Spark SQL から触れるように一時テーブルとして登録
df.registerTempTable('tmp')
SELECT
val,
CASE
WHEN val <= 1 THEN 'label_1'
WHEN val <= 2 THEN 'label_2'
ELSE 'label_3'
END AS label
FROM tmp
| val | label |
|---|---|
| 1 | label_1 |
| 2 | label_2 |
| 3 | label_3 |
SQL の場合は当然、先に書いた WHEN の条件が優先される。
PySpark の場合
from pyspark.sql import functions as F
df_label = df.withColumn('label',
F.when(F.col('val') <= 1, 'label_1')
.when(F.col('val') <= 2, 'label_2')
.otherwise('label_3')
)
display(df_label)
| val | label |
|---|---|
| 1 | label_1 |
| 2 | label_2 |
| 3 | label_3 |
PySpark で when をチェーンした場合でも、Spark SQL と同様に先に書いた when の条件が優先されるらしい。