方針
Spark の Dataset を使った集計処理の単体テストを書きたいとします。たとえば、
case class RowA(a: String, b: Long, c: Double)
という case class があったとして、
val left = Seq(
RowA("x", 1, 1.1),
RowA("x", 2, 2.2),
RowA("x", 3, 3.3),
RowA("y", 4, 4.4),
RowA("y", 5, 5.5)
).toDS().groupBy("a").agg(sum("b") as "b", sum("c") as "c").as[RowA]
GROUP BY + SUM 処理をした結果が、
val right = Seq(
RowA("x", 6, 6.6),
RowA("y", 9, 9.9)
).toDS().as[RowA]
と一致する、という単体テストを書きたいです。
ここでテストフレームワークとして ScalaTest を使っているならば、
left shouldEqual right
と一致判定したいところですね。ところが、これは
[a: string, b: bigint ... 1 more field] did not equal [a: string, b: bigint ... 1 more field]
となってしまいます。なぜなら shouldEqual
は Equality 型クラスに依存していて、ここで参照されるデフォルトの Equality はインスタンスとしての一致判定をしてしまうからです。
ということで Dataset のための Equality 型クラスのインスタンスを実装したくなります。ただ、この方向性はあまり実用的ではありません。というのも Equality 型クラスではエラーメッセージを扱わないため Dataset のどこで具体的な差異が出ているのか伝えられないのですね。
そこで方針を変えて Dataset についての判定もろもろを一纏めにしたユーティリティ関数を作ることにします。なお type DataFrame = Dataset[Row]
なので、それで DataFrame もカバーできます。
判定の流れ
まずレコード数が合っているかどうか確認しましょう。
withClue("Dataset count:") {
left.count() shouldEqual right.count()
}
withClue
はエラーメッセージの先頭に追記することができます。そのため、これに引っかかった場合は、
Dataset count: 2 did not equal 3
というエラーメッセージを出せるので分かりやすいですね。
次にスキーマが合っているかどうか確認しましょう。
withClue("Dataset schema:") {
// アグリゲーションなどの際に意図せず nullable になるため nullable は比較しない
val lSchema = left.schema.fields.map(f => (f.name, f.dataType))
val rSchema = right.schema.fields.map(f => (f.name, f.dataType))
lSchema shouldEqual rSchema
}
エラーメッセージはこのようになります。
Dataset schema: Array((a,StringType), (b,LongType), (c,DoubleType)) did not equal Array((a,StringType), (b,LongType), (c,DoubleType), (d,BooleanType))
そして最後に各レコードの値一つ一つが合っているか判定したいのですが、 Spark は集計結果のレコード順序を一般的に保証しません。そのため left と right を同じようにソートする必要があります。
val columns = left.schema.fields.map(f => new ColumnName(f.name))
val lSortedRows = left.sort(columns: _*).toDF().collect()
val rSortedRows = right.sort(columns: _*).toDF().collect()
あとはレコード順に値を判定していけばいいのですが、ここで気をつけなければならないのが浮動小数点数の誤差です。たとえば、
1.2 * 3 shouldEqual 3.6
3.5999999999999996 did not equal 3.6
ということが起こるので、許容誤差を
implicit val doubleEquality: Equality[Double] = TolerantNumerics.tolerantDoubleEquality(0.01)
のように決めてやる必要があります。
よって値の判定はこのようになります。
for ((l, r) <- lSortedRows.zip(rSortedRows)) {
withClue(s"Dataset row: $l did not equal $r:") {
for (field <- l.schema.fields) field.dataType match {
case FloatType =>
shouldEqualField[Float](l, r, field)
case DoubleType =>
shouldEqualField[Double](l, r, field)
case _ =>
shouldEqualField[Any](l, r, field)
}
}
}
private def shouldEqualField[B: Equality](left: Row, right: Row, field: StructField): Unit = {
// TolerantNumerics.tolerantDoubleEquality などは null と null の比較で false を返すが Dataset の比較テストでは通したい
val lv = left.getAs[B](field.name)
val rv = right.getAs[B](field.name)
if (lv != null || rv != null) {
lv shouldEqual rv
}
}
エラーメッセージは、
Dataset row: [x,6,6.6] did not equal [x,6,6.8]: 6.6 did not equal 6.8
となるので Dataset のどのレコードで差異があるか一目瞭然ですね。
ユーティリティ関数
ということで、まとめると次のようなユーティリティ関数になります。
レコード数の一致、スキーマの一致、浮動小数点数の誤差を加味した値の一致、を一気に判定します。必要に応じてカスタマイズしてください。
import org.apache.spark.sql._
import org.apache.spark.sql.types._
import org.scalactic.Equality
import org.scalatest.Matchers._
object SparkTestUtil {
implicit class DatasetHelper[A](val left: Dataset[A]) extends AnyVal {
def shouldEqualDataset(right: Dataset[A])(implicit floatEq: Equality[Float], doubleEq: Equality[Double]): Unit = {
withClue("Dataset count:") {
left.count() shouldEqual right.count()
}
withClue("Dataset schema:") {
// アグリゲーションなどの際に意図せず nullable になるため nullable は比較しない
val lSchema = left.schema.fields.map(f => (f.name, f.dataType))
val rSchema = right.schema.fields.map(f => (f.name, f.dataType))
lSchema shouldEqual rSchema
}
val columns = left.schema.fields.map(f => new ColumnName(f.name))
val lSortedRows = left.sort(columns: _*).toDF().collect()
val rSortedRows = right.sort(columns: _*).toDF().collect()
for ((l, r) <- lSortedRows.zip(rSortedRows)) {
withClue(s"Dataset row: $l did not equal $r:") {
for (field <- l.schema.fields) field.dataType match {
case FloatType =>
shouldEqualField[Float](l, r, field)
case DoubleType =>
shouldEqualField[Double](l, r, field)
case _ =>
shouldEqualField[Any](l, r, field)
}
}
}
}
private def shouldEqualField[B: Equality](left: Row, right: Row, field: StructField): Unit = {
// TolerantNumerics.tolerantDoubleEquality などは null と null の比較で false を返すが Dataset の比較テストでは通したい
val lv = left.getAs[B](field.name)
val rv = right.getAs[B](field.name)
if (lv != null || rv != null) {
lv shouldEqual rv
}
}
}
}
このユーティリティ関数を使ったテストサンプルです。
import SparkTestUtil.DatasetHelper
import org.apache.spark.sql.functions.sum
import org.scalactic.{Equality, TolerantNumerics}
import org.scalatest.{DoNotDiscover, FunSuite, Matchers}
case class RowA(a: String, b: Long, c: Double)
case class RowB(a: String, b: Long, c: Double)
@DoNotDiscover
class SparkTestUtilTest extends FunSuite with Matchers with SparkTest {
test("shouldEqualDataset") {
import spark.implicits._
implicit val doubleEquality: Equality[Double] = TolerantNumerics.tolerantDoubleEquality(0.01)
val left = Seq(
RowA("x", 1, 1.1),
RowA("x", 2, 2.2),
RowA("x", 3, 3.3),
RowA("y", 4, 4.4),
RowA("y", 5, 5.5)
).toDS().groupBy("a").agg(sum("b") as "b", sum("c") as "c").as[RowA]
val right = Seq(
RowB("y", 9, 9.909),
RowB("x", 6, 6.6)
).toDS().as[RowA]
left shouldEqualDataset right
}
}
import org.apache.spark.sql._
import org.scalatest.{BeforeAndAfterAll, Suites}
trait SparkTest {
val spark: SparkSession = SparkSession.builder().master("local[*]").getOrCreate()
}
class SparkTests
extends Suites(
new SparkTestUtilTest
)
with SparkTest
with BeforeAndAfterAll {
override def afterAll(): Unit = {
spark.stop()
}
}
動作確認環境
name := "spark-test"
version := "0.1.0"
javacOptions ++= Seq("-source", "11", "-target", "11")
scalacOptions ++= Seq("-deprecation", "-feature", "-unchecked")
scalaVersion := "2.12.8"
libraryDependencies ++= Seq(
"org.apache.spark" %% "spark-core" % "2.4.3",
"org.apache.spark" %% "spark-sql" % "2.4.3",
"org.scalatest" %% "scalatest" % "3.0.8" % "test"
)
sbt.version=1.2.8
$ java -version
openjdk version "11.0.2" 2019-01-15
OpenJDK Runtime Environment AdoptOpenJDK (build 11.0.2+9)
OpenJDK 64-Bit Server VM AdoptOpenJDK (build 11.0.2+9, mixed mode)