Help us understand the problem. What is going on with this article?

SparkのDataFrame/Datasetをテストするユーティリティ関数

More than 1 year has passed since last update.

方針

Spark の Dataset を使った集計処理の単体テストを書きたいとします。たとえば、

case class RowA(a: String, b: Long, c: Double)

という case class があったとして、

val left = Seq(
  RowA("x", 1, 1.1),
  RowA("x", 2, 2.2),
  RowA("x", 3, 3.3),
  RowA("y", 4, 4.4),
  RowA("y", 5, 5.5)
).toDS().groupBy("a").agg(sum("b") as "b", sum("c") as "c").as[RowA]

GROUP BY + SUM 処理をした結果が、

val right = Seq(
  RowA("x", 6, 6.6),
  RowA("y", 9, 9.9)
).toDS().as[RowA]

と一致する、という単体テストを書きたいです。
ここでテストフレームワークとして ScalaTest を使っているならば、

left shouldEqual right

と一致判定したいところですね。ところが、これは

[a: string, b: bigint ... 1 more field] did not equal [a: string, b: bigint ... 1 more field]

となってしまいます。なぜなら shouldEqual は Equality 型クラスに依存していて、ここで参照されるデフォルトの Equality はインスタンスとしての一致判定をしてしまうからです。

ということで Dataset のための Equality 型クラスのインスタンスを実装したくなります。ただ、この方向性はあまり実用的ではありません。というのも Equality 型クラスではエラーメッセージを扱わないため Dataset のどこで具体的な差異が出ているのか伝えられないのですね。
そこで方針を変えて Dataset についての判定もろもろを一纏めにしたユーティリティ関数を作ることにします。なお type DataFrame = Dataset[Row] なので、それで DataFrame もカバーできます。

判定の流れ

まずレコード数が合っているかどうか確認しましょう。

withClue("Dataset count:") {
  left.count() shouldEqual right.count()
}

withClue はエラーメッセージの先頭に追記することができます。そのため、これに引っかかった場合は、

Dataset count: 2 did not equal 3

というエラーメッセージを出せるので分かりやすいですね。

次にスキーマが合っているかどうか確認しましょう。

withClue("Dataset schema:") {
  // アグリゲーションなどの際に意図せず nullable になるため nullable は比較しない
  val lSchema = left.schema.fields.map(f => (f.name, f.dataType))
  val rSchema = right.schema.fields.map(f => (f.name, f.dataType))

  lSchema shouldEqual rSchema
}

エラーメッセージはこのようになります。

Dataset schema: Array((a,StringType), (b,LongType), (c,DoubleType)) did not equal Array((a,StringType), (b,LongType), (c,DoubleType), (d,BooleanType))

そして最後に各レコードの値一つ一つが合っているか判定したいのですが、 Spark は集計結果のレコード順序を一般的に保証しません。そのため left と right を同じようにソートする必要があります。

val columns = left.schema.fields.map(f => new ColumnName(f.name))
val lSortedRows = left.sort(columns: _*).toDF().collect()
val rSortedRows = right.sort(columns: _*).toDF().collect()

あとはレコード順に値を判定していけばいいのですが、ここで気をつけなければならないのが浮動小数点数の誤差です。たとえば、

1.2 * 3 shouldEqual 3.6

3.5999999999999996 did not equal 3.6

ということが起こるので、許容誤差を

implicit val doubleEquality: Equality[Double] = TolerantNumerics.tolerantDoubleEquality(0.01)

のように決めてやる必要があります。
よって値の判定はこのようになります。

for ((l, r) <- lSortedRows.zip(rSortedRows)) {
  withClue(s"Dataset row: $l did not equal $r:") {
    for (field <- l.schema.fields) field.dataType match {
      case FloatType =>
        shouldEqualField[Float](l, r, field)

      case DoubleType =>
        shouldEqualField[Double](l, r, field)

      case _ =>
        shouldEqualField[Any](l, r, field)
    }
  }
}

private def shouldEqualField[B: Equality](left: Row, right: Row, field: StructField): Unit = {
  // TolerantNumerics.tolerantDoubleEquality などは null と null の比較で false を返すが Dataset の比較テストでは通したい
  val lv = left.getAs[B](field.name)
  val rv = right.getAs[B](field.name)
  if (lv != null || rv != null) {
    lv shouldEqual rv
  }
}

エラーメッセージは、

Dataset row: [x,6,6.6] did not equal [x,6,6.8]: 6.6 did not equal 6.8

となるので Dataset のどのレコードで差異があるか一目瞭然ですね。

ユーティリティ関数

ということで、まとめると次のようなユーティリティ関数になります。
レコード数の一致、スキーマの一致、浮動小数点数の誤差を加味した値の一致、を一気に判定します。必要に応じてカスタマイズしてください。

src/test/scala/SparkTestUtil.scala
import org.apache.spark.sql._
import org.apache.spark.sql.types._
import org.scalactic.Equality
import org.scalatest.Matchers._

object SparkTestUtil {

  implicit class DatasetHelper[A](val left: Dataset[A]) extends AnyVal {
    def shouldEqualDataset(right: Dataset[A])(implicit floatEq: Equality[Float], doubleEq: Equality[Double]): Unit = {

      withClue("Dataset count:") {
        left.count() shouldEqual right.count()
      }

      withClue("Dataset schema:") {
        // アグリゲーションなどの際に意図せず nullable になるため nullable は比較しない
        val lSchema = left.schema.fields.map(f => (f.name, f.dataType))
        val rSchema = right.schema.fields.map(f => (f.name, f.dataType))

        lSchema shouldEqual rSchema
      }

      val columns = left.schema.fields.map(f => new ColumnName(f.name))
      val lSortedRows = left.sort(columns: _*).toDF().collect()
      val rSortedRows = right.sort(columns: _*).toDF().collect()

      for ((l, r) <- lSortedRows.zip(rSortedRows)) {
        withClue(s"Dataset row: $l did not equal $r:") {
          for (field <- l.schema.fields) field.dataType match {
            case FloatType =>
              shouldEqualField[Float](l, r, field)

            case DoubleType =>
              shouldEqualField[Double](l, r, field)

            case _ =>
              shouldEqualField[Any](l, r, field)
          }
        }
      }
    }

    private def shouldEqualField[B: Equality](left: Row, right: Row, field: StructField): Unit = {
      // TolerantNumerics.tolerantDoubleEquality などは null と null の比較で false を返すが Dataset の比較テストでは通したい
      val lv = left.getAs[B](field.name)
      val rv = right.getAs[B](field.name)
      if (lv != null || rv != null) {
        lv shouldEqual rv
      }
    }
  }

}

このユーティリティ関数を使ったテストサンプルです。

src/test/scala/SparkTestUtilTest.scala
import SparkTestUtil.DatasetHelper
import org.apache.spark.sql.functions.sum
import org.scalactic.{Equality, TolerantNumerics}
import org.scalatest.{DoNotDiscover, FunSuite, Matchers}

case class RowA(a: String, b: Long, c: Double)

case class RowB(a: String, b: Long, c: Double)

@DoNotDiscover
class SparkTestUtilTest extends FunSuite with Matchers with SparkTest {

  test("shouldEqualDataset") {
    import spark.implicits._

    implicit val doubleEquality: Equality[Double] = TolerantNumerics.tolerantDoubleEquality(0.01)

    val left = Seq(
      RowA("x", 1, 1.1),
      RowA("x", 2, 2.2),
      RowA("x", 3, 3.3),
      RowA("y", 4, 4.4),
      RowA("y", 5, 5.5)
    ).toDS().groupBy("a").agg(sum("b") as "b", sum("c") as "c").as[RowA]

    val right = Seq(
      RowB("y", 9, 9.909),
      RowB("x", 6, 6.6)
    ).toDS().as[RowA]

    left shouldEqualDataset right
  }
}
src/test/scala/SparkTest.scala
import org.apache.spark.sql._
import org.scalatest.{BeforeAndAfterAll, Suites}

trait SparkTest {
  val spark: SparkSession = SparkSession.builder().master("local[*]").getOrCreate()
}

class SparkTests
    extends Suites(
      new SparkTestUtilTest
    )
    with SparkTest
    with BeforeAndAfterAll {

  override def afterAll(): Unit = {
    spark.stop()
  }

}

動作確認環境

build.sbt
name := "spark-test"
version := "0.1.0"
javacOptions ++= Seq("-source", "11", "-target", "11")
scalacOptions ++= Seq("-deprecation", "-feature", "-unchecked")
scalaVersion := "2.12.8"

libraryDependencies ++= Seq(
  "org.apache.spark" %% "spark-core" % "2.4.3",
  "org.apache.spark" %% "spark-sql" % "2.4.3",
  "org.scalatest" %% "scalatest" % "3.0.8" % "test"
)
project/build.properties
sbt.version=1.2.8
$ java -version
openjdk version "11.0.2" 2019-01-15
OpenJDK Runtime Environment AdoptOpenJDK (build 11.0.2+9)
OpenJDK 64-Bit Server VM AdoptOpenJDK (build 11.0.2+9, mixed mode)
piyo7
機械の言葉と、数学の言葉と、それから人間の言葉を喋るよ♪
https://piyo7.github.io/deep-sister/
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away