4
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

Apache Spark 3.5で導入されたTesting APIを試す

Last updated at Posted at 2023-09-17

久しぶりのLLM以外の記事ですが、どちらかといえばこちらが本業です。

導入

Apache Spark 3.5がリリースされました。

下記のDatabricks公式blogでも取り上げられています。

これを読んでいていると、PysparkのTesting APIというものに目が引かれました。
類似のモジュールは既にあったと思うのですが、公式が出してくれると地味に捗ります。

というわけで、下のドキュメントを基に、ウォークスルーしてみました。

環境はいつものようにDatabricksを使います。DBRは14.0です。

Step1. ビルトイン関数を試す

ドキュメントの内容ほぼそのままを実行してきます。

最初はデータフレーム同士の比較を行う関数assertDataFrameEqualでテストします。

import pyspark.testing
from pyspark.testing.utils import assertDataFrameEqual

# Example 1
df1 = spark.createDataFrame(data=[("1", 1000), ("2", 3000)], schema=["id", "amount"])
df2 = spark.createDataFrame(data=[("1", 1000), ("2", 3000)], schema=["id", "amount"])
assertDataFrameEqual(df1, df2)  # pass, DataFrames are identical
# Example 2
df1 = spark.createDataFrame(data=[("1", 0.1), ("2", 3.23)], schema=["id", "amount"])
df2 = spark.createDataFrame(data=[("1", 0.109), ("2", 3.23)], schema=["id", "amount"])
assertDataFrameEqual(df1, df2, rtol=1e-1)  # pass, DataFrames are approx equal by rtol

どちらのサンプルもエラーなく処理が完了します。
Example2はコメントにあるように、浮動小数は完全一致でなくても通るようですね。

次はスキーマの同一性をテストするassertSchemaEqualを使うサンプル。

from pyspark.testing.utils import assertSchemaEqual
from pyspark.sql.types import StructType, StructField, ArrayType, DoubleType

s1 = StructType([StructField("names", ArrayType(DoubleType(), True), True)])
s2 = StructType([StructField("names", ArrayType(DoubleType(), True), True)])

assertSchemaEqual(s1, s2)  # pass, schemas are identical

こちらもエラーなく処理が完了します。

さて、試しにエラーを起こしてみましょう。
最後のサンプルで、カラムの名前をnamesからnames1に変更して実行してみます。

結果
PySparkAssertionError: [DIFFERENT_SCHEMA] Schemas do not match.
--- actual
+++ expected
- StructType([StructField('names', ArrayType(DoubleType(), True), True)])
+ StructType([StructField('names1', ArrayType(DoubleType(), True), True)])
?                               +

無事(?)エラーが出ました。
違いが見やすく出力されます。

Step2. unittestの中で使う

こちらの内容を実行します。
また、意図的にテストが失敗するように一部修正しています。

# データとテスト対象の関数を準備

import pyspark.sql.functions as F

sample_data = [
    {"name": "John    D.", "age": 30},
    {"name": "Alice   G.", "age": 25},
    {"name": "Bob  T.", "age": 35},
    {"name": "Eve   A.", "age": 28},
]

df = spark.createDataFrame(sample_data)

# Remove additional spaces in name
def remove_extra_spaces(df, column_name):
    df_transformed = df.withColumn(
        column_name, F.regexp_replace(F.col(column_name), "\\s+", " ")
    )

    return df_transformed
# テストの実行

import unittest

# Define unit test
class TestTranformation(unittest.TestCase):
    def test_single_space(self):
        sample_data = [
            {"name": "John    D.", "age": 30},
            {"name": "Alice   G.", "age": 25},
            {"name": "Bob  T.", "age": 35},
            {"name": "Eve   A.", "age": 28},
        ]

        # Create a Spark DataFrame
        original_df = spark.createDataFrame(sample_data)

        # Apply the transformation function from before
        transformed_df = remove_extra_spaces(original_df, "name")

        expected_data = [
            {"name": "John D.", "age": 31},
            {"name": "Alice A.", "age": 25},
            {"name": "Bob T.", "age": 40},
            {"name": "Eve A.", "age": 28},
        ]

        expected_df = spark.createDataFrame(expected_data)

        assertDataFrameEqual(transformed_df, expected_df)

suite = unittest.TestLoader().loadTestsFromTestCase(TestTranformation)
runner = unittest.TextTestRunner(
    verbosity=0,
)
runner.run(suite)
結果
======================================================================
FAIL: test_single_space (__main__.TestTranformation)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/root/.ipykernel/1109/command-2200363102628332-3245114994", line 28, in test_single_space
    assertDataFrameEqual(transformed_df, expected_df)
  File "/databricks/spark/python/pyspark/instrumentation_utils.py", line 48, in wrapper
    res = func(*args, **kwargs)
  File "/databricks/spark/python/pyspark/testing/utils.py", line 548, in assertDataFrameEqual
    assert_rows_equal(df_list, expected_list)
  File "/databricks/spark/python/pyspark/testing/utils.py", line 528, in assert_rows_equal
    raise PySparkAssertionError(
pyspark.errors.exceptions.base.PySparkAssertionError: [DIFFERENT_ROWS] Results do not match: ( 75.00000 % )
--- actual
+++ expected

- Row(age=25, name='Alice G.')
?                         ^

+ Row(age=25, name='Alice A.')
?                         ^

********************

- Row(age=30, name='John D.')
?          ^

+ Row(age=31, name='John D.')
?          ^

********************

- Row(age=35, name='Bob T.')
?         ^^

+ Row(age=40, name='Bob T.')
?         ^^

********************


----------------------------------------------------------------------
Ran 1 test in 0.720s

FAILED (failures=1)

unittestでも問題なくassertDataFrameEqualが動作してますね。
結果もexpected dataとの差異をわかるように出力してくれています。

その他

pytestでも試しましたが、同様に動きました。
ドキュメントにも記載されているように、特定のテストモジュールに依存していないので好きなテストフレームワークで使うことができますね。

まとめ

こういう機能は地味ですが実務において役に立つので、実装されたのは嬉しい。

4
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?