shapelessのNatを調べてみた。
Natとは
Natは、型レベルの自然数を表すもの。型で数字を表す。
0は、_0
という特殊な型がある。
1は、Succ[_0]
という型で表す。
2は、Succ[Succ[_0]]
という型で表す。
3は、Succ[Succ[Succ[_0]]]
という型で(ry
という感じでSuccを使って表す。
ペアノの公理とかいうやつらしい。
試してみる
import shapeless._
scala> Nat.toInt[_0]
res12: Int = 0
scala> Nat.toInt[Succ[_0]]
res13: Int = 1
scala> Nat.toInt[Succ[Succ[_0]]]
res14: Int = 2
scala> Nat.toInt[Succ[Succ[Succ[_0]]]]
res15: Int = 3
Succの数が数字になっている。
四則演算やら
足し算
import ops.nat._
scala> val sum = Sum[Succ[Succ[_0]], Succ[_0]] // 2 + 1
sum: shapeless.ops.nat.Sum[shapeless.Succ[shapeless.Succ[shapeless._0]],shapeless.Succ[shapeless._0]]{type Out = shapeless.Succ[shapeless.Succ[shapeless.Succ[shapeless._0]]]} = shapeless.ops.nat$Sum$$anon$5@15f6b98e
scala> Nat.toInt[sum.Out]
res21: Int = 3
引き算
scala> val diff = Diff[Succ[Succ[_0]], Succ[_0]] // 2 - 1
diff: shapeless.ops.nat.Diff[shapeless.Succ[shapeless.Succ[shapeless._0]],shapeless.Succ[shapeless._0]]{type Out = shapeless.Succ[shapeless._0]} = shapeless.ops.nat$Diff$$anon$7@3705577d
scala> Nat.toInt[diff.Out]
res22: Int = 1
かけ算(Prod)と割り算(Div)、MinやMax、比較のLTなどもある。
このように型レベルで計算ができるので、型レベルのフィボナッチ数列の計算やクイックソートができたりする。
型レベルのフィボナッチ数列の計算をやってみる
↓をベースに実装。(rubyなのはググッたら最初に出たからで意味はない)
def fib(n)
if n <= 1
return n
end
fib(n - 2) + fib(n - 1)
end
↓が実装。コメントもつけた。
import shapeless._
import shapeless.ops.nat.LTEq.<=
import shapeless.ops.nat.{Sum, ToInt}
/**
* @tparam N ベースにしたコードのnに相当
* @tparam Out 結果の数
*/
class Fib[N <: Nat, Out <: Nat] {
type Res = Out // 外から結果が見えるように
}
object Fib {
/**
* @param n フィボナッチ数列のn項。IntからNatへのimplicitが提供されているのでIntでも渡せる。
*/
def apply[Out <: Nat](n: Nat)(implicit fib: Fib[n.N, Out]) = fib
/**
* if (n <= 1) n の実装。
*/
implicit def fib0[N <: Nat](implicit n: N <= Succ[_0]) = new Fib[N, N]
/**
* fib(n - 2) + fib(n - 1) の実装。
*
* @tparam N 再帰で減らしていく数。Succ[Succ[N]]なのでNは-2したもの。
* @tparam L fib(n - 2)の結果に相当。
* @tparam R fib(n - 1)の結果に相当。Succ[N]を付けることで-2に+1している。
* @return LとRを足したものを結果として返す。
*/
implicit def fib1[N <: Nat, L <: Nat, R <: Nat](implicit l: Fib[N, L], r: Fib[Succ[N], R], sum: Sum[L, R]) =
new Fib[Succ[Succ[N]], sum.Out]
}
テストしてみる。
// チェック用の関数を定義
def check(fib: Fib[_, _], n1: Nat)(implicit ev: fib.Res =:= n1.N) {}
check(Fib(0), 0)
check(Fib(1), 1)
check(Fib(2), 1)
check(Fib(3), 2)
check(Fib(4), 3)
check(Fib(5), 5)
check(Fib(6), 8)
check(Fib(7), 13)
check(Fib(0), 1) // マッチしないとコンパイルエラー
// <console>:66: error: Cannot prove that shapeless._0 =:= shapeless.Succ[shapeless._0].
// check(Fib(0), 1)
階乗も
実装してみた。
3の階乗だったら、321=6 みたいなやつ。
class Fac[N <: Nat, Out <: Nat] {
type Res = Out
}
object Fac {
def apply[Out <: Nat](n: Nat)(implicit fac: Fac[n.N, Out]) = fac
implicit def fac1 = new Fac[Succ[_0], Succ[_0]]
implicit def fac2[N <: Nat, NN <: Nat](implicit fac: Fac[N, NN], prod: Prod[Succ[N], NN]) =
new Fac[Succ[N], prod.Out]
}
def check(fac: Fac[_, _], n1: Nat)(implicit ev: fac.Res =:= n1.N) {}
check(Fac(1), 1)
check(Fac(2), 2)
check(Fac(3), 6)
check(Fac(4), 24)
check(Fac(5), 120)
5の階乗は120なので3分くらい計算に時間がかかる\( 'ω')/
このように型レベルの計算が色々できる。
コンパイル時に数値のチェックができるので、知っておけば何かしら使いどころが見つかる...かもしれない? :(;゙゚'ω゚'):
HListのapplyで
HListのapplyでn番目の値を取る時にも使われていて、n番目がなければコンパイルエラーになる。
scala> val xs = 1 :: "a" :: HNil
xs: shapeless.::[Int,shapeless.::[String,shapeless.HNil]] = 1 :: a :: HNil
scala> xs(0)
res122: Int = 1
scala> xs(1)
res123: String = a
scala> xs(2)
<console>:54: error: Implicit not found: shapeless.Ops.At[shapeless.::[Int,shapeless.::[String,shapeless.HNil]], nat_$macro$68.N]. You requested to access an element at the position nat_$macro$68.N, but the HList shapeless.::[Int,shapeless.::[String,shapeless.HNil]] is too short.
xs(2)
ソースを読んでみると、applyを呼ぶと、Atを読むようだ。Atでは、HListをループしつつ、applyで渡したnを-1していき、nが0になったらそのときのHListのHeadが返るようだ。
というわけでNat便利ですね。(∩´ᵕ`∩)