LoginSignup
1

Apache Spark 3.5で導入されたTesting APIを試す

Last updated at Posted at 2023-09-17

久しぶりのLLM以外の記事ですが、どちらかといえばこちらが本業です。

導入

Apache Spark 3.5がリリースされました。

下記のDatabricks公式blogでも取り上げられています。

これを読んでいていると、PysparkのTesting APIというものに目が引かれました。
類似のモジュールは既にあったと思うのですが、公式が出してくれると地味に捗ります。

というわけで、下のドキュメントを基に、ウォークスルーしてみました。

環境はいつものようにDatabricksを使います。DBRは14.0です。

Step1. ビルトイン関数を試す

ドキュメントの内容ほぼそのままを実行してきます。

最初はデータフレーム同士の比較を行う関数assertDataFrameEqualでテストします。

import pyspark.testing
from pyspark.testing.utils import assertDataFrameEqual

# Example 1
df1 = spark.createDataFrame(data=[("1", 1000), ("2", 3000)], schema=["id", "amount"])
df2 = spark.createDataFrame(data=[("1", 1000), ("2", 3000)], schema=["id", "amount"])
assertDataFrameEqual(df1, df2)  # pass, DataFrames are identical
# Example 2
df1 = spark.createDataFrame(data=[("1", 0.1), ("2", 3.23)], schema=["id", "amount"])
df2 = spark.createDataFrame(data=[("1", 0.109), ("2", 3.23)], schema=["id", "amount"])
assertDataFrameEqual(df1, df2, rtol=1e-1)  # pass, DataFrames are approx equal by rtol

どちらのサンプルもエラーなく処理が完了します。
Example2はコメントにあるように、浮動小数は完全一致でなくても通るようですね。

次はスキーマの同一性をテストするassertSchemaEqualを使うサンプル。

from pyspark.testing.utils import assertSchemaEqual
from pyspark.sql.types import StructType, StructField, ArrayType, DoubleType

s1 = StructType([StructField("names", ArrayType(DoubleType(), True), True)])
s2 = StructType([StructField("names", ArrayType(DoubleType(), True), True)])

assertSchemaEqual(s1, s2)  # pass, schemas are identical

こちらもエラーなく処理が完了します。

さて、試しにエラーを起こしてみましょう。
最後のサンプルで、カラムの名前をnamesからnames1に変更して実行してみます。

結果
PySparkAssertionError: [DIFFERENT_SCHEMA] Schemas do not match.
--- actual
+++ expected
- StructType([StructField('names', ArrayType(DoubleType(), True), True)])
+ StructType([StructField('names1', ArrayType(DoubleType(), True), True)])
?                               +

無事(?)エラーが出ました。
違いが見やすく出力されます。

Step2. unittestの中で使う

こちらの内容を実行します。
また、意図的にテストが失敗するように一部修正しています。

# データとテスト対象の関数を準備

import pyspark.sql.functions as F

sample_data = [
    {"name": "John    D.", "age": 30},
    {"name": "Alice   G.", "age": 25},
    {"name": "Bob  T.", "age": 35},
    {"name": "Eve   A.", "age": 28},
]

df = spark.createDataFrame(sample_data)

# Remove additional spaces in name
def remove_extra_spaces(df, column_name):
    df_transformed = df.withColumn(
        column_name, F.regexp_replace(F.col(column_name), "\\s+", " ")
    )

    return df_transformed
# テストの実行

import unittest

# Define unit test
class TestTranformation(unittest.TestCase):
    def test_single_space(self):
        sample_data = [
            {"name": "John    D.", "age": 30},
            {"name": "Alice   G.", "age": 25},
            {"name": "Bob  T.", "age": 35},
            {"name": "Eve   A.", "age": 28},
        ]

        # Create a Spark DataFrame
        original_df = spark.createDataFrame(sample_data)

        # Apply the transformation function from before
        transformed_df = remove_extra_spaces(original_df, "name")

        expected_data = [
            {"name": "John D.", "age": 31},
            {"name": "Alice A.", "age": 25},
            {"name": "Bob T.", "age": 40},
            {"name": "Eve A.", "age": 28},
        ]

        expected_df = spark.createDataFrame(expected_data)

        assertDataFrameEqual(transformed_df, expected_df)

suite = unittest.TestLoader().loadTestsFromTestCase(TestTranformation)
runner = unittest.TextTestRunner(
    verbosity=0,
)
runner.run(suite)
結果
======================================================================
FAIL: test_single_space (__main__.TestTranformation)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/root/.ipykernel/1109/command-2200363102628332-3245114994", line 28, in test_single_space
    assertDataFrameEqual(transformed_df, expected_df)
  File "/databricks/spark/python/pyspark/instrumentation_utils.py", line 48, in wrapper
    res = func(*args, **kwargs)
  File "/databricks/spark/python/pyspark/testing/utils.py", line 548, in assertDataFrameEqual
    assert_rows_equal(df_list, expected_list)
  File "/databricks/spark/python/pyspark/testing/utils.py", line 528, in assert_rows_equal
    raise PySparkAssertionError(
pyspark.errors.exceptions.base.PySparkAssertionError: [DIFFERENT_ROWS] Results do not match: ( 75.00000 % )
--- actual
+++ expected

- Row(age=25, name='Alice G.')
?                         ^

+ Row(age=25, name='Alice A.')
?                         ^

********************

- Row(age=30, name='John D.')
?          ^

+ Row(age=31, name='John D.')
?          ^

********************

- Row(age=35, name='Bob T.')
?         ^^

+ Row(age=40, name='Bob T.')
?         ^^

********************


----------------------------------------------------------------------
Ran 1 test in 0.720s

FAILED (failures=1)

unittestでも問題なくassertDataFrameEqualが動作してますね。
結果もexpected dataとの差異をわかるように出力してくれています。

その他

pytestでも試しましたが、同様に動きました。
ドキュメントにも記載されているように、特定のテストモジュールに依存していないので好きなテストフレームワークで使うことができますね。

まとめ

こういう機能は地味ですが実務において役に立つので、実装されたのは嬉しい。

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
What you can do with signing up
1