結構悩みました。
やりたいこと
- Sparkデータフレームに含まれるArrayTypeのカラムに任意の要素が格納されている。
- このカラムの配列の要素をキーごとにカウントして、カウントを格納する新たなカラムとする。カラム名は要素名から構成する。
まず参考になったのはこちら
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
[
{"ID": 1, "tags": ["A", "B", "C"]},
{"ID": 2, "tags": ["A", "D", "E"]},
{"ID": 3, "tags": ["A", "C", "F"]},
]
)
tags = [
x[0]
for x in df.select(F.explode("tags").alias("tags"))
.distinct()
.orderBy("tags")
.collect()
]
df = df.select(
"*",
*[
F.array_contains("tags", tag).alias("tags{}".format(tag)).cast("integer")
for tag in tags
]
)
array_containsを使って、特定の要素を含む場合には要素名を含む列の値を1にしています(例: タグA
を含むかどうかはtagsA
列で表現)。
要素数をカウントするような関数があれば良かったのですが見当たらず。高階関数と組み合わせるかとも思いましたが、最終的にはudfと組み合わせました。
from pyspark.sql import functions as F
df = spark.createDataFrame(
[
{"ID": 1, "tags": ["A", "B", "C", "B", "D"]},
{"ID": 2, "tags": ["A", "D", "E", "E", "A"]},
{"ID": 3, "tags": ["A", "C", "F", "F", "F"]},
]
)
display(df)
unique_tags = [
x[0]
for x in df.select(F.explode("tags").alias("tags"))
.distinct()
.orderBy("tags")
.collect()
]
print(unique_tags)
from pyspark.sql.functions import udf
# リストの特定の要素数をカウントするUDF
@udf("int")
def count_element_udf(target_list, element):
return target_list.count(element)
df = df.select(
"*",
*[count_element_udf("tags", F.lit(tag)).alias(f"count_{tag}") for tag in unique_tags]
)
display(df)
UDFを呼び出す際の二つ目の引数で、F.lit
を指定してリテラル値にしないと列名を参照しようとしてエラーになります。