Haskell
Scala
TDD
test
圏論

確率モナドのCats実装とモナド則のDisciplineテスト

『すごいHaskellたのしく学ぼう』の14章8節に、カスタムモナドのサンプルとして「確率モナド」が載っている。これをCatsを用いてScalaで書いてみる。

また本には「できたばかりのモナドが、きちんとモナド則を満たしているか試すことも、とても重要です。」との記述もある。これは Disciplineを使って確認してみる。

準備

依存ライブラリ等は以下。

    compilerPlugin("org.spire-math" %% "kind-projector" % "0.9.4"),
    "org.typelevel" % "cats-core_2.12" % "1.0.0-RC2",
    "org.typelevel" % "cats-laws_2.12" % "1.0.0-RC2",
    "org.scalacheck" %% "scalacheck" % "1.13.5",
    "org.scalactic" %% "scalactic" % "3.0.4",
    "org.scalatest" %% "scalatest" % "3.0.4" % "test",

また同じ型の要素からなるタプル(A, A)に関数f:A=>Rを適用するヘルパをパッケージオブジェクトに書いている。

package object ex01 {
  implicit class T2mapper[A](val t: (A, A)) extends AnyVal {
    def map[R](f: A => R): (R, R) = (f(t._1), f(t._2))
    def foldMap[R, B](f: A => R, g: (R, R) => B): B = g.tupled(t map f)
  }
}

REPLでの挙動。

scala> (1, 2) map (_ + 10)
res0: (Int, Int) = (11,12)
scala> (1, 2) foldMap[Int, Int] (_ * 10, _ + _)
res2: Int = 30

gist :> (A,A) =f=> (R,R) where f:A=>R

データ型

本と同じようにProb型で、確率(というか離散確率分布)を表す。またコインや分数なども定義する。

Coin

sealed traitcase objectで、裏・表を値に持つコイン型を定義した。

sealed trait Coin
case object Heads extends Coin
case object Tails extends Coin

Haskellだとdata Coin = Heads | Tails deriving (Show, Eq)で済むが、Scalaの場合少し冗長になる。

Rational

確率を表す分数(有理数)のデータ型として、本ではHaskellData.RatioRationalを使っているが、ここでは自前でRationalを簡単に実装した。

case class Rational(n: Int, d: Int) {
  def *(r: Rational) = Rational.apply(n * r.n, d * r.d)
}
object Rational {
  def apply(n: Int, d: Int): Rational =
    (n, d) foldMap (_ / gcd(n, d), new Rational(_, _))
  private def gcd(a: Int, b: Int): Int = if (b == 0) a else gcd(b, a % b)
}

既約にできるものは、コンパニオンのapplyの時点で約分している(Rational(9, 15)Rational(3, 5)など)。

Prob

以下が離散確率の分布を表すProbクラス。事象(確率変数)と確率のペアをリストとして持つ。

case class Prob[+A](es: List[Event[A]]) {
  def map[B](f: A => B): Prob[B] = Prob(es.map { case (a, r) => f(a) -> r })
  def flatMap[B](f: A => Prob[B]): Prob[B] = flatten(map(f))
}
object Prob {
  type Event[+A] = (A, Rational)
  def flatten[A](ppa: Prob[Prob[A]]): Prob[A] = Prob( for {
    (pa, r1) <- ppa.es
    (a,  r2) <- pa.es
  } yield a -> r1 * r2)
}

flatMapがこのカスタムモナドの核心の部分で、外側のProbでもっている確率を内側のProbの確率にかけてフラットにしている。

モナドインスタンス

以下が「確率モナド」のインスタンス。

object ProbInstances {
  implicit def probMonad: Monad[Prob] = new Monad[Prob] {
    def flatMap[A, B](pa: Prob[A])(f: A => Prob[B]): Prob[B] = pa flatMap f

    def tailRecM[A, B](a: A)(f: A => Prob[Either[A, B]]): Prob[B] = {
      val buf = List.newBuilder[(B, Rational)]
      @tailrec def go(pes: List[(Prob[Either[A, B]], Rational)]): Unit = pes match {
        case (Prob(e :: es), r0) :: tail => e match {
          case (Right(b), r) => buf += (b -> r * r0) ; go(Prob(es) -> r0 :: tail)
          case (Left(a2), r) => go(f(a2) -> r :: Prob(es) -> r :: tail)
        }
        case (Prob(Nil), _) :: tail => go(tail)
        case Nil                    => ()
      }
      go(pure(f(a)).es)
      Prob(buf.result)
    }
    override def pure[A](a: A): Prob[A] = Prob(List(a -> Rational(1, 1)))
  }
}

CatsMonadの場合、pureflatMapに加えてさらにtailRecMも実装する必要がありややこしくなった。tailRecM自体の解説は割愛するが、後述のDisciplineではtailRecM周りのルールセットも提供されていて、「法則」に則っているか確認できる。

ちなみにScalaztailrecMは、Monadではなく別クラスのBindRecに置かれている。

gist :> Scala version of Prob Monad (from LYAHFGG)

REPLで動作確認

ここまでくると REPLから動かせる。

scala> val coin = Prob(List((Heads, Rational(1, 2)), (Tails, Rational(1, 2))))

scala> val loadedCoin = Prob(List((Heads, Rational(1, 10)), (Tails, Rational(9, 10))))

scala> val result = for {
  a <- coin
  b <- coin
  c <- loadedCoin
} yield List(a, b, c).forall(_ == Tails)

scala> result.es.foreach(println)
(false,Rational(1,40))
(false,Rational(9,40))
(false,Rational(1,40))
(false,Rational(9,40))
(false,Rational(1,40))
(false,Rational(9,40))
(false,Rational(1,40))
(true,Rational(9,40))

本に載ってるghci上の動作と同じように動く。

テストコード

ベースとなるテストツールはScalaTestを使った。以下のようなテストクラスで実行できる。

class ProbMonadSpec extends FunSuite with Discipline {

  implicit def defaultEq[A]: Eq[A] = (x, y) => x == y

  implicit def genRational: Gen[Rational] = for {
    n <- Gen.choose(-3, 3)
    d <- Gen.choose(-3, 3).retryUntil(_ != 0)
  } yield Rational.apply(n, d)
  implicit def arbRational: Arbitrary[Rational] = Arbitrary(genRational)

  implicit def genProb[A](implicit arb: Arbitrary[List[Event[A]]]): Gen[Prob[A]] =
    arb.arbitrary.map(Prob(_))
  implicit def arbProb[A](implicit arb: Arbitrary[A]): Arbitrary[Prob[A]] = Arbitrary(genProb[A])

  implicit def genCoin: Gen[Coin] = Gen.oneOf[Coin](Heads, Tails)
  implicit def cogenCoin: Cogen[Coin] = Cogen { _ match {
    case Heads => 1L
    case Tails => 0L
  }}
  implicit def arbCoin: Arbitrary[Coin] = Arbitrary(genCoin)

  checkAll("Monad[Prob[Coin]]", MonadTests[Prob].monad[Coin, Coin, Coin])
}

最後のcheckAllMonadTests#monadで定義されたルールが検証される。前提となるimplicitは以下のとおり。

  • defaultEq: 期待値と実際値を比較するためのEq
  • genRational/arbRational: 確率Rationalを表す分数の生成。
  • genProb/arbProb: 確率分布Probを表す分数の生成。
  • genCoin/cogenCoin/arbCoin: コインの生成。

gist :> discipline for Prob Monad

実行結果

以下、IntelliJ からの実行のスクショ。

スクリーンショット 2017-12-24 3.42.03.png

27個もあるが、本に載ってる3箇条のモナド則は、~.monad left identity~.monad right identity~.flatMap associativityあたりが対応。

上ですこし触れたとおりtailRecMの実装はけっこう難しいが、例えばスタックセーフになってなければ~.tailRecM stack safetyでエラーになるし、あるいはtailRecM内の計算に矛盾があると
~.tailRecM consistent flatMap~.flatMap from tailRecM consistencyなどが失敗するので、却って安心感がある。

よく見ると11秒もかかっているが、テストケースのサイズを小さくできるかもしれない。ちなみにCoin型でなくIntなどでも正しく動く。

所感

  • 慣れるまではコンパイルを通すことからしてやや難しい。またコンパイルできてもテストコードが悪いのかテスト対象コードが悪いのかすぐには分かりにくく、はじめは苦労するかもしれない。
  • 一旦コンパイルが通って思い通りのテスト結果が得られるようになると、かなり安心感がある。tailRecMの実装などはこれがないと逆に難しかった。

TODO

  • refinedなどを使って、Rationalの分母を非ぜロにしたり、分子≦分母とするような制約を、型で表現するといいかもしれない。Probも確率分布と考えるれば合計値が1になるような制約が必要かもしれない。