0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

Sparkのユーザー定義関数、高階関数

Last updated at Posted at 2024-03-27

2024/4/12に翔泳社よりApache Spark徹底入門を出版します!

書籍のサンプルノートブックをウォークスルーしていきます。Python/Chapter05/5-1 Spark SQL & UDFsの前半となります。

翻訳ノートブックのリポジトリはこちら。

ノートブックはこちら

ユーザー定義関数

Apache Sparkでは数多くの関数を提供していますが、Sparkの柔軟性によってデータエンジニアやデータサイエンティストは自分の関数を定義することができます(すなわち、user-defined functionあるいはUDF)。

from pyspark.sql.types import LongType

# 3乗関数の作成
def cubed(s):
  return s * s * s

# UDFの登録
spark.udf.register("cubed", cubed, LongType())

# 一時ビューの作成
spark.range(1, 9).createOrReplaceTempView("udf_test")
spark.sql("SELECT id, cubed(id) AS id_cubed FROM udf_test").show()
+---+--------+
| id|id_cubed|
+---+--------+
|  1|       1|
|  2|       8|
|  3|      27|
|  4|      64|
|  5|     125|
|  6|     216|
|  7|     343|
|  8|     512|
+---+--------+

Pandas UDFを用いたPySpark UDFのスピードアップと分散処理

PySpark UDFで不可避な問題の一つが、Scala UDFよりも遅いということです。これは、PySpark UDFでは、処理コストが非常に高いJVMとPythonの間でのデータ移動が必要なためです。この問題を解決するために、pandas UDF(ベクトル化UDFとも呼ばれます)がApache Spark 2.3の一部として導入されました。これは、データ転送にApache Sparkを活用し、データの操作でpandasを活用します。デコレーターとしてpandas_udfキーワードを用いるか、関数自身をラッピングするためにpandas UDFを定義するだけです。データがApache Arrowフォーマットになると、Pythonプロセスで利用可能なフォーマットになっているので、シリアライズ/pickleの必要がなくなります。個々の入力を行ごとに操作するのではなく、pandasのシリーズやデータフレームを操作することになります(すなわち、ベクトル化処理)。

import pandas as pd
from pyspark.sql.functions import col, pandas_udf
from pyspark.sql.types import LongType

# 3乗関数の宣言 
def cubed(a: pd.Series) -> pd.Series:
    return a * a * a

# 3乗関数に対するpandas UDFの作成 
cubed_udf = pandas_udf(cubed, returnType=LongType())

pandasデータフレームの使用

# Pandasシリーズの作成
x = pd.Series([1, 2, 3])

# ローカルのPandasデータに対して実行されるpandas_udfの関数
print(cubed(x))
0     1
1     8
2    27
dtype: int64

Sparkデータフレームの使用

# Sparkデータフレームの作成
df = spark.range(1, 4)

# Sparkベクトル化UDFとして関数を実行
df.select("id", cubed_udf(col("id"))).show()
+---+---------+
| id|cubed(id)|
+---+---------+
|  1|        1|
|  2|        8|
|  3|       27|
+---+---------+

データフレームやSpark SQLにおける高階関数

複雑なデータタイプはシンプルなデータタイプの集合体なので、直接複雑なデータタイプを操作したくなるものです。記事Introducing New Built-in and Higher-Order Functions for Complex Data Types in Apache Spark 2.4で言及したように、複雑なデータ型の操作では、2つの典型的なソリューションがあります。

  1. 以下のコードで示しているように、ネストされた構造を個々の行に分割し、何かしらの関数を適用し、ネストされた構造を再構築する(オプション1をご覧ください)
  2. ユーザー定義関数(UDF)を構築する
# 配列型データセットの作成
arrayData = [[1, (1, 2, 3)], [2, (2, 3, 4)], [3, (3, 4, 5)]]

# スキーマの作成
from pyspark.sql.types import *
arraySchema = (StructType([
      StructField("id", IntegerType(), True), 
      StructField("values", ArrayType(IntegerType()), True)
      ]))

# データフレームの作成
df = spark.createDataFrame(spark.sparkContext.parallelize(arrayData), arraySchema)
df.createOrReplaceTempView("table")
df.printSchema()
df.show()
root
 |-- id: integer (nullable = true)
 |-- values: array (nullable = true)
 |    |-- element: integer (containsNull = true)

+---+---------+
| id|   values|
+---+---------+
|  1|[1, 2, 3]|
|  2|[2, 3, 4]|
|  3|[3, 4, 5]|
+---+---------+

オプション1: ExplodeとCollect

このネストされたSQL文では、最初に値の中の個々の要素(value)に対応する(idを持つ)新規行を作成するexplode(values)を実行します。

spark.sql("""
SELECT id, collect_list(value + 1) AS newValues
  FROM  (SELECT id, explode(values) AS value
        FROM table) x
 GROUP BY id
""").show()
+---+---------+
| id|newValues|
+---+---------+
|  1|[2, 3, 4]|
|  2|[3, 4, 5]|
|  3|[4, 5, 6]|
+---+---------+

オプション2: ユーザー定義関数

同じタスク(valuesのそれぞれの要素の値に1を足す)を実行するために、加算のオペレーションを実行するためにそれぞれの要素(value)に対するイテレーションのためにmapを用いるユーザー定義関数(UDF)を作成することもできます。

from pyspark.sql.types import IntegerType
from pyspark.sql.types import ArrayType

# UDFの作成
def addOne(values):
  return [value + 1 for value in values]

# UDFの登録
spark.udf.register("plusOneIntPy", addOne, ArrayType(IntegerType()))  

# データのクエリー
spark.sql("SELECT id, plusOneIntPy(values) AS values FROM table").show()
+---+---------+
| id|   values|
+---+---------+
|  1|[2, 3, 4]|
|  2|[3, 4, 5]|
|  3|[4, 5, 6]|
+---+---------+

高階関数

上述したビルトインの関数に加え、引数として匿名のラムダ関数を受け取る高階関数があります。

from pyspark.sql.types import *
schema = StructType([StructField("celsius", ArrayType(IntegerType()))])

t_list = [[35, 36, 32, 30, 40, 42, 38]], [[31, 32, 34, 55, 56]]
t_c = spark.createDataFrame(t_list, schema)
t_c.createOrReplaceTempView("tC")

# データフレームの表示
t_c.show()
+--------------------+
|             celsius|
+--------------------+
|[35, 36, 32, 30, ...|
|[31, 32, 34, 55, 56]|
+--------------------+

Transform

transform(array<T>, function<T, U>): array<U>

transform関数は、(map関数と同じように)入力配列のそれぞれの要素に関数を適用することで配列を生成します。

# 気温の配列に対して摂氏から華氏を計算
spark.sql("""SELECT celsius, transform(celsius, t -> ((t * 9) div 5) + 32) as fahrenheit FROM tC""").show()
+--------------------+--------------------+
|             celsius|          fahrenheit|
+--------------------+--------------------+
|[35, 36, 32, 30, ...|[95, 96, 89, 86, ...|
|[31, 32, 34, 55, 56]|[87, 89, 93, 131,...|
+--------------------+--------------------+

Filter

filter(array<T>, function<T, Boolean>): array<T>

filter関数はboolean関数がtrueになる要素を持つ配列を生成します。

# 気温配列を temperatures > 38C でフィルタリング
spark.sql("""SELECT celsius, filter(celsius, t -> t > 38) as high FROM tC""").show()
+--------------------+--------+
|             celsius|    high|
+--------------------+--------+
|[35, 36, 32, 30, ...|[40, 42]|
|[31, 32, 34, 55, 56]|[55, 56]|
+--------------------+--------+

Exists

exists(array<T>, function<T, V, Boolean>): Boolean

exists関数はboolean関数が入力配列のいずれかの要素でtrueになる場合にはtrueを返します。

# 気温の配列に38Cが含まれるかどうか
spark.sql("""
SELECT celsius, exists(celsius, t -> t = 38) as threshold
FROM tC
""").show()
+--------------------+---------+
|             celsius|threshold|
+--------------------+---------+
|[35, 36, 32, 30, ...|     true|
|[31, 32, 34, 55, 56]|    false|
+--------------------+---------+

Reduce

reduce(array<T>, B, function<B, T, B>, function<B, R>)

reduce関数はfunction<B, T, B>を用いてバッファBに配列をマージし、最後のバッファに最終関数function<B, R>を適用することで、配列を単一の値に集約します。

# 平均気温を計算し、Fに変換
spark.sql("""
SELECT celsius, 
       reduce(
          celsius, 
          0, 
          (t, acc) -> t + acc, 
          acc -> (acc div size(celsius) * 9 div 5) + 32
        ) as avgFahrenheit 
  FROM tC
""").show()
+--------------------+-------------+
|             celsius|avgFahrenheit|
+--------------------+-------------+
|[35, 36, 32, 30, ...|           96|
|[31, 32, 34, 55, 56]|          105|
+--------------------+-------------+

はじめてのDatabricks

はじめてのDatabricks

Databricks無料トライアル

Databricks無料トライアル

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?