久しぶりのLLM以外の記事ですが、どちらかといえばこちらが本業です。
導入
Apache Spark 3.5がリリースされました。
下記のDatabricks公式blogでも取り上げられています。
これを読んでいていると、PysparkのTesting APIというものに目が引かれました。
類似のモジュールは既にあったと思うのですが、公式が出してくれると地味に捗ります。
というわけで、下のドキュメントを基に、ウォークスルーしてみました。
環境はいつものようにDatabricksを使います。DBRは14.0です。
Step1. ビルトイン関数を試す
ドキュメントの内容ほぼそのままを実行してきます。
最初はデータフレーム同士の比較を行う関数assertDataFrameEqual
でテストします。
import pyspark.testing
from pyspark.testing.utils import assertDataFrameEqual
# Example 1
df1 = spark.createDataFrame(data=[("1", 1000), ("2", 3000)], schema=["id", "amount"])
df2 = spark.createDataFrame(data=[("1", 1000), ("2", 3000)], schema=["id", "amount"])
assertDataFrameEqual(df1, df2) # pass, DataFrames are identical
# Example 2
df1 = spark.createDataFrame(data=[("1", 0.1), ("2", 3.23)], schema=["id", "amount"])
df2 = spark.createDataFrame(data=[("1", 0.109), ("2", 3.23)], schema=["id", "amount"])
assertDataFrameEqual(df1, df2, rtol=1e-1) # pass, DataFrames are approx equal by rtol
どちらのサンプルもエラーなく処理が完了します。
Example2はコメントにあるように、浮動小数は完全一致でなくても通るようですね。
次はスキーマの同一性をテストするassertSchemaEqual
を使うサンプル。
from pyspark.testing.utils import assertSchemaEqual
from pyspark.sql.types import StructType, StructField, ArrayType, DoubleType
s1 = StructType([StructField("names", ArrayType(DoubleType(), True), True)])
s2 = StructType([StructField("names", ArrayType(DoubleType(), True), True)])
assertSchemaEqual(s1, s2) # pass, schemas are identical
こちらもエラーなく処理が完了します。
さて、試しにエラーを起こしてみましょう。
最後のサンプルで、カラムの名前をnamesからnames1に変更して実行してみます。
PySparkAssertionError: [DIFFERENT_SCHEMA] Schemas do not match.
--- actual
+++ expected
- StructType([StructField('names', ArrayType(DoubleType(), True), True)])
+ StructType([StructField('names1', ArrayType(DoubleType(), True), True)])
? +
無事(?)エラーが出ました。
違いが見やすく出力されます。
Step2. unittestの中で使う
こちらの内容を実行します。
また、意図的にテストが失敗するように一部修正しています。
# データとテスト対象の関数を準備
import pyspark.sql.functions as F
sample_data = [
{"name": "John D.", "age": 30},
{"name": "Alice G.", "age": 25},
{"name": "Bob T.", "age": 35},
{"name": "Eve A.", "age": 28},
]
df = spark.createDataFrame(sample_data)
# Remove additional spaces in name
def remove_extra_spaces(df, column_name):
df_transformed = df.withColumn(
column_name, F.regexp_replace(F.col(column_name), "\\s+", " ")
)
return df_transformed
# テストの実行
import unittest
# Define unit test
class TestTranformation(unittest.TestCase):
def test_single_space(self):
sample_data = [
{"name": "John D.", "age": 30},
{"name": "Alice G.", "age": 25},
{"name": "Bob T.", "age": 35},
{"name": "Eve A.", "age": 28},
]
# Create a Spark DataFrame
original_df = spark.createDataFrame(sample_data)
# Apply the transformation function from before
transformed_df = remove_extra_spaces(original_df, "name")
expected_data = [
{"name": "John D.", "age": 31},
{"name": "Alice A.", "age": 25},
{"name": "Bob T.", "age": 40},
{"name": "Eve A.", "age": 28},
]
expected_df = spark.createDataFrame(expected_data)
assertDataFrameEqual(transformed_df, expected_df)
suite = unittest.TestLoader().loadTestsFromTestCase(TestTranformation)
runner = unittest.TextTestRunner(
verbosity=0,
)
runner.run(suite)
======================================================================
FAIL: test_single_space (__main__.TestTranformation)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/root/.ipykernel/1109/command-2200363102628332-3245114994", line 28, in test_single_space
assertDataFrameEqual(transformed_df, expected_df)
File "/databricks/spark/python/pyspark/instrumentation_utils.py", line 48, in wrapper
res = func(*args, **kwargs)
File "/databricks/spark/python/pyspark/testing/utils.py", line 548, in assertDataFrameEqual
assert_rows_equal(df_list, expected_list)
File "/databricks/spark/python/pyspark/testing/utils.py", line 528, in assert_rows_equal
raise PySparkAssertionError(
pyspark.errors.exceptions.base.PySparkAssertionError: [DIFFERENT_ROWS] Results do not match: ( 75.00000 % )
--- actual
+++ expected
- Row(age=25, name='Alice G.')
? ^
+ Row(age=25, name='Alice A.')
? ^
********************
- Row(age=30, name='John D.')
? ^
+ Row(age=31, name='John D.')
? ^
********************
- Row(age=35, name='Bob T.')
? ^^
+ Row(age=40, name='Bob T.')
? ^^
********************
----------------------------------------------------------------------
Ran 1 test in 0.720s
FAILED (failures=1)
unittestでも問題なくassertDataFrameEqual
が動作してますね。
結果もexpected dataとの差異をわかるように出力してくれています。
その他
pytestでも試しましたが、同様に動きました。
ドキュメントにも記載されているように、特定のテストモジュールに依存していないので好きなテストフレームワークで使うことができますね。
まとめ
こういう機能は地味ですが実務において役に立つので、実装されたのは嬉しい。