Edited at

SparkのDataFrame/Datasetをテストするユーティリティ関数


方針

Spark の Dataset を使った集計処理の単体テストを書きたいとします。たとえば、

case class RowA(a: String, b: Long, c: Double)

という case class があったとして、

val left = Seq(

RowA("x", 1, 1.1),
RowA("x", 2, 2.2),
RowA("x", 3, 3.3),
RowA("y", 4, 4.4),
RowA("y", 5, 5.5)
).toDS().groupBy("a").agg(sum("b") as "b", sum("c") as "c").as[RowA]

GROUP BY + SUM 処理をした結果が、

val right = Seq(

RowA("x", 6, 6.6),
RowA("y", 9, 9.9)
).toDS().as[RowA]

と一致する、という単体テストを書きたいです。

ここでテストフレームワークとして ScalaTest を使っているならば、

left shouldEqual right

と一致判定したいところですね。ところが、これは


[a: string, b: bigint ... 1 more field] did not equal [a: string, b: bigint ... 1 more field]


となってしまいます。なぜなら shouldEqual は Equality 型クラスに依存していて、ここで参照されるデフォルトの Equality はインスタンスとしての一致判定をしてしまうからです。

ということで Dataset のための Equality 型クラスのインスタンスを実装したくなります。ただ、この方向性はあまり実用的ではありません。というのも Equality 型クラスではエラーメッセージを扱わないため Dataset のどこで具体的な差異が出ているのか伝えられないのですね。

そこで方針を変えて Dataset についての判定もろもろを一纏めにしたユーティリティ関数を作ることにします。なお type DataFrame = Dataset[Row] なので、それで DataFrame もカバーできます。


判定の流れ

まずレコード数が合っているかどうか確認しましょう。

withClue("Dataset count:") {

left.count() shouldEqual right.count()
}

withClue はエラーメッセージの先頭に追記することができます。そのため、これに引っかかった場合は、


Dataset count: 2 did not equal 3


というエラーメッセージを出せるので分かりやすいですね。

次にスキーマが合っているかどうか確認しましょう。

withClue("Dataset schema:") {

// アグリゲーションなどの際に意図せず nullable になるため nullable は比較しない
val lSchema = left.schema.fields.map(f => (f.name, f.dataType))
val rSchema = right.schema.fields.map(f => (f.name, f.dataType))

lSchema shouldEqual rSchema
}

エラーメッセージはこのようになります。


Dataset schema: Array((a,StringType), (b,LongType), (c,DoubleType)) did not equal Array((a,StringType), (b,LongType), (c,DoubleType), (d,BooleanType))


そして最後に各レコードの値一つ一つが合っているか判定したいのですが、 Spark は集計結果のレコード順序を一般的に保証しません。そのため left と right を同じようにソートする必要があります。

val columns = left.schema.fields.map(f => new ColumnName(f.name))

val lSortedRows = left.sort(columns: _*).toDF().collect()
val rSortedRows = right.sort(columns: _*).toDF().collect()

あとはレコード順に値を判定していけばいいのですが、ここで気をつけなければならないのが浮動小数点数の誤差です。たとえば、

1.2 * 3 shouldEqual 3.6


3.5999999999999996 did not equal 3.6


ということが起こるので、許容誤差を

implicit val doubleEquality: Equality[Double] = TolerantNumerics.tolerantDoubleEquality(0.01)

のように決めてやる必要があります。

よって値の判定はこのようになります。

for ((l, r) <- lSortedRows.zip(rSortedRows)) {

withClue(s"Dataset row: $l did not equal $r:") {
for (field <- l.schema.fields) field.dataType match {
case FloatType =>
shouldEqualField[Float](l, r, field)

case DoubleType =>
shouldEqualField[Double](l, r, field)

case _ =>
shouldEqualField[Any](l, r, field)
}
}
}

private def shouldEqualField[B: Equality](left: Row, right: Row, field: StructField): Unit = {
// TolerantNumerics.tolerantDoubleEquality などは null と null の比較で false を返すが Dataset の比較テストでは通したい
val lv = left.getAs[B](field.name)
val rv = right.getAs[B](field.name)
if (lv != null || rv != null) {
lv shouldEqual rv
}
}

エラーメッセージは、


Dataset row: [x,6,6.6] did not equal [x,6,6.8]: 6.6 did not equal 6.8


となるので Dataset のどのレコードで差異があるか一目瞭然ですね。


ユーティリティ関数

ということで、まとめると次のようなユーティリティ関数になります。

レコード数の一致、スキーマの一致、浮動小数点数の誤差を加味した値の一致、を一気に判定します。必要に応じてカスタマイズしてください。


src/test/scala/SparkTestUtil.scala

import org.apache.spark.sql._

import org.apache.spark.sql.types._
import org.scalactic.Equality
import org.scalatest.Matchers._

object SparkTestUtil {

implicit class DatasetHelper[A](val left: Dataset[A]) extends AnyVal {
def shouldEqualDataset(right: Dataset[A])(implicit floatEq: Equality[Float], doubleEq: Equality[Double]): Unit = {

withClue("Dataset count:") {
left.count() shouldEqual right.count()
}

withClue("Dataset schema:") {
// アグリゲーションなどの際に意図せず nullable になるため nullable は比較しない
val lSchema = left.schema.fields.map(f => (f.name, f.dataType))
val rSchema = right.schema.fields.map(f => (f.name, f.dataType))

lSchema shouldEqual rSchema
}

val columns = left.schema.fields.map(f => new ColumnName(f.name))
val lSortedRows = left.sort(columns: _*).toDF().collect()
val rSortedRows = right.sort(columns: _*).toDF().collect()

for ((l, r) <- lSortedRows.zip(rSortedRows)) {
withClue(s"Dataset row: $l did not equal $r:") {
for (field <- l.schema.fields) field.dataType match {
case FloatType =>
shouldEqualField[Float](l, r, field)

case DoubleType =>
shouldEqualField[Double](l, r, field)

case _ =>
shouldEqualField[Any](l, r, field)
}
}
}
}

private def shouldEqualField[B: Equality](left: Row, right: Row, field: StructField): Unit = {
// TolerantNumerics.tolerantDoubleEquality などは null と null の比較で false を返すが Dataset の比較テストでは通したい
val lv = left.getAs[B](field.name)
val rv = right.getAs[B](field.name)
if (lv != null || rv != null) {
lv shouldEqual rv
}
}
}

}


このユーティリティ関数を使ったテストサンプルです。


src/test/scala/SparkTestUtilTest.scala

import SparkTestUtil.DatasetHelper

import org.apache.spark.sql.functions.sum
import org.scalactic.{Equality, TolerantNumerics}
import org.scalatest.{DoNotDiscover, FunSuite, Matchers}

case class RowA(a: String, b: Long, c: Double)

case class RowB(a: String, b: Long, c: Double)

@DoNotDiscover
class SparkTestUtilTest extends FunSuite with Matchers with SparkTest {

test("shouldEqualDataset") {
import spark.implicits._

implicit val doubleEquality: Equality[Double] = TolerantNumerics.tolerantDoubleEquality(0.01)

val left = Seq(
RowA("x", 1, 1.1),
RowA("x", 2, 2.2),
RowA("x", 3, 3.3),
RowA("y", 4, 4.4),
RowA("y", 5, 5.5)
).toDS().groupBy("a").agg(sum("b") as "b", sum("c") as "c").as[RowA]

val right = Seq(
RowB("y", 9, 9.909),
RowB("x", 6, 6.6)
).toDS().as[RowA]

left shouldEqualDataset right
}
}



src/test/scala/SparkTest.scala

import org.apache.spark.sql._

import org.scalatest.{BeforeAndAfterAll, Suites}

trait SparkTest {
val spark: SparkSession = SparkSession.builder().master("local[*]").getOrCreate()
}

class SparkTests
extends Suites(
new SparkTestUtilTest
)
with SparkTest
with BeforeAndAfterAll {

override def afterAll(): Unit = {
spark.stop()
}

}



動作確認環境


build.sbt

name := "spark-test"

version := "0.1.0"
javacOptions ++= Seq("-source", "11", "-target", "11")
scalacOptions ++= Seq("-deprecation", "-feature", "-unchecked")
scalaVersion := "2.12.8"

libraryDependencies ++= Seq(
"org.apache.spark" %% "spark-core" % "2.4.3",
"org.apache.spark" %% "spark-sql" % "2.4.3",
"org.scalatest" %% "scalatest" % "3.0.8" % "test"
)



project/build.properties

sbt.version=1.2.8


$ java -version

openjdk version "11.0.2" 2019-01-15
OpenJDK Runtime Environment AdoptOpenJDK (build 11.0.2+9)
OpenJDK 64-Bit Server VM AdoptOpenJDK (build 11.0.2+9, mixed mode)