LoginSignup
14
15

More than 5 years have passed since last update.

Scala の Stream pattern matching でハマった話

Posted at

TL;DR

正格評価なのか非正格評価(遅延評価)なのかはしっかり把握しましょう、という話。
あとおまけで、素数無限列挙の話についてもちらっと。

初めに

Scala は、遅延リストとしての Stream を備えています。
Stream.from(n) とすれば n 以降の整数を無限に列挙1できますし、その他フィボナッチ数列や素数列も定義できます。もちろん遅延リストなので、基本的には実際の値は必要になるまで算出されません。
また Haskell などと同様、Scala でも Stream は『最初の要素』と『残り』のようなパターンマッチングに対応しています。これを利用して柔軟な Stream 処理が可能なのですが…。
ちょっと込み入ったことをしようとすると、すぐに StackOverflowError に見舞われてしまいます。
ということで、自分用の覚書を兼ねて、その原因と対策をまとめてみます2

Scala の Stream

初〜中級者向けのおさらいです。知ってる方は読み飛ばしてください。

Scala では、Stream クラスによって Stream の構築や各種処理が可能です。よくある mapfilter などもそのクラスのメソッドとして用意されています3
また Haskell の : 演算子と同様に、#:: を利用して、最初の要素とそれ以降の Stream を結合(右結合)することで Stream を構築することもできます。右側は遅延評価されるので、再帰的定義も可能です。

例:フィボナッチ数列

scala> val fibs:Stream[BigInt] = 0 #:: fibs.zip(1 #:: fibs).map{t => t._1+t._2}
fibs: Stream[BigInt] = Stream(0, ?)

scala> fibs.take(10).toList
res0: List[BigInt] = List(0, 1, 1, 2, 3, 5, 8, 13, 21, 34)

scala> fibs(10)
res1: BigInt = 55

scala> fibs  // ↓算出済の要素までは文字列表現中に列挙される
res2: Stream[BigInt] = Stream(0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, ?)

また Scala の Stream はパターンマッチングにも標準対応しています。
これも Haskell と同様、#:: がそのままコンストラクタパターンとして利用可能です。

scala> val a #:: b #:: c = fibs
a: BigInt = 0
b: BigInt = 1
c: scala.collection.immutable.Stream[BigInt] = Stream(1, 2, 3, 5, 8, 13, 21, 34, 55, ?)

#:: パターンマッチングの NG 例

(狭義)単調増加列である2つの Stream が与えられて、それを昇順を保ったままマージする(同じ値は1つの要素にまとめる)ことを考えましょう。
こんな実装がすぐに思いつきます4

def merge[A <% Ordered[A]](a: Stream[A], b: Stream[A]): Stream[A] = 
  // ** Both `a` and `b` assume to be strictly increasing. **
  (a, b) match {
    case (x #:: xs, y #:: _) if x < y => x #:: merge(xs, b)
    case (x #:: _, y #:: ys) if x > y => y #:: merge(a, ys)
    case (x #:: xs, _ #:: ys) => x #:: merge(xs, ys)
    case (_, Stream.Empty) => a
    case _ => b
  }

これはこれで、実装としてあながち間違いではなく、普通に使う分には問題は発生せず期待通りに動作します。

scala> merge(fibs, ((3:BigInt) to (10:BigInt)).toStream).take(20).toList
res3: List[BigInt] = List(0, 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 13, 21, 34, 55, 89, 144, 233, 377)

ただしこれを利用して、再帰的定義をしようとすると、問題が発生する場合があります。
例えば、ハミング数5を列挙することを意図した以下のコード:

scala> val hamming: Stream[BigInt] = 1 #:: merge(hamming.map{_*2}, 
     |   merge(hamming.map{_*3}, hamming.map{_*5}))
hamming: Stream[BigInt] = Stream(1, ?)

2つめ以降の値を取得しようとした途端に java.lang.StackOverflowError で落ちます。

scala> hamming(0)  // これはOK
res4: BigInt = 1

scala> hamming(1)
java.lang.StackOverflowError
// 以下延々とスタックトレースが出力される

ちなみに。
Haskell で同様な実装を行ってもエラーは発生せず期待通りに動作します。

merge :: Ord a => [a] -> [a] -> [a]
merge a@(x:xs) b@(y:ys)
    | x < y = x : merge xs b
    | x > y = y : merge a ys
    | otherwise = x : merge xs ys
merge a [] = a
merge [] b = b

hamming :: [Integer]
hamming = 1 : merge (map (*2) hamming) (merge (map (*3) hamming) (map (*5) hamming))

hamming !! 1
-- => 2

take 20 hamming
-- => [1,2,3,4,5,6,8,9,10,12,15,16,18,20,24,25,27,30,32,36]

Haskell と Scala のこの違いは、以下の2点から来るものです:

  • Haskell は基本 遅延評価(非正格評価)なので、x:xs というパターンで受け取った xs もその具体的な値が必要になるまで評価されないが、Scala では x #:: xs というパターンで受け取った xs は正格評価されてしまう。
  • Scala の Stream は、head(最初の要素)は正格(tail(2つめ以降のStream)は非正格(遅延))になっている。その結果、a.tail が評価されると、その時点でその最初の要素(=元のStreamの2番目の要素)が正格評価されてしまう。

つまり、ある Stream a に対する x #:: xs というパターンマッチは、(x, xs) = (a.head, a.tail) と同じ意味になり、xsa.tail が束縛されようとした時点で、その最初の要素(=a.tail.head)が評価されてしまいます。このときその時点でまだ具体的な値が定まっていなければ、定義に従ってその値を算出しようとします。hamming の場合、その結果定義に従って再帰を無限に辿ろうとしてしまって、StackOverflow、というわけです。

解決例

今回は「とにかくまともに動くかつそこそこシンプルな merge の実装を得る」ことが目的。なので、以下のように実装変更すればOKです。

def merge[A <% Ordered[A]](a: Stream[A], b: Stream[A]): Stream[A] = 
  // ** Both `a` and `b` assume to be strictly increasing. **
  (a.headOption, b.headOption) match {
    case (Some(x), Some(y)) if x < y => x #:: merge(a.tail, b)
    case (Some(x), Some(y)) if x > y => y #:: merge(a, b.tail)
    case (Some(x), Some(_)) => x #:: merge(a.tail, b.tail)
    case (_, None) => a
    case _ => b
  }

つまり。
「パターンマッチの時は head だけ見る」これだけ。
普通にやったら tail は(少なくとも tail の head が)正格評価されてしまうので、「どう遅延させるか」を考えて工夫するのも良い6のですが、「だったらマッチング時に tail を評価させなければOK」ということ。
だって結局 tail の内容が必要になるのは、マージの次のステップ以降だけなのですから。「#::を利用したパターンマッチング」にこだわる必要なんてないんですから。

一応、ハミング数の実装で確認↓

scala> val hamming: Stream[BigInt] = 1 #:: merge(hamming.map{_*2}, 
     |   merge(hamming.map{_*3}, hamming.map{_*5}))
hamming: Stream[BigInt] = Stream(1, ?)

scala> hamming(1)
res5: BigInt = 2

scala> hamming.take(20).toList
res6: List[BigInt] = List(1, 2, 3, 4, 5, 6, 8, 9, 10, 12, 15, 16, 18, 20, 24, 25, 27, 30, 32, 36)

ちゃんと期待通り動きました。

まとめ

Scala は基本が「正格評価」。一部明示的に「遅延評価」と指示されたコード片(lazy val 〜 とか、a #:: ss とか7)以外は、正格評価されてしまうと把握してコーディングすべし。
StackOverflowError が発生したら、そのあたりで無限再帰をやらかしている可能性があるので、気をつけて見直すべし8

おまけ:本来の意味での「エラトステネスの篩」

「Scala 素数列挙」で検索すると、やっぱりよくある「filterを使った再帰的実装」が出てきますし、それを「エラトステネスの篩の実装例」って紹介している記事さえありますけれど、あれ、正確には「エラトステネスの篩」じゃない、ってどれくらいの方が把握しているんでしょう?
あれはただの「試し割り法」です。だから無茶苦茶列挙遅いし、ちょっと大きめの(例えば10000番目の)素数を取得しようとすると、すぐに OutOfMemory とか StackOverflow になっちゃいます(運良く取得できてもかなり時間かかります)。

本来の「エラトステネスの篩を利用した素数無限列挙」の実装は、こうです9

primes.scala
object Primes {
  def merge[A <% Ordered[A]](a: Stream[A], b: Stream[A]): Stream[A] = 
    // ** Both `a` and `b` assume to be strictly increasing. **
    (a.headOption, b.headOption) match {
      case (Some(x), Some(y)) if x < y => x #:: merge(a.tail, b)
      case (Some(x), Some(y)) if x > y => y #:: merge(a, b.tail)
      case (Some(x), Some(_)) => x #:: merge(a.tail, b.tail)
      case (_, None) => a
      case _ => b
    }

  def joinL[A <% Ordered[A]](xs: Stream[Stream[A]]): Stream[A] = 
    xs.head.head #:: merge(xs.head.tail, joinL(xs.tail))

  def minus[A <% Ordered[A]](a: Stream[A], b: Stream[A]): Stream[A] = 
    // ** Both `a` and `b` assume to be strictly increasing. **
    (a.headOption, b.headOption) match {
      case (Some(x), Some(y)) if x < y => x #:: minus(a.tail, b)
      case (Some(x), Some(y)) if x > y => minus(a, b.tail)
      case (Some(_), Some(_)) => minus(a.tail, b.tail)
      case _ => a
    }

  val primes: Stream[BigInt] = 2 #:: minus(Stream.iterate(3:BigInt)(_+1), 
          joinL(primes.map{p => Stream.iterate(p*p)(_+p)}))

  def main(args: Array[String]): Unit = {
    printExecutionTime { println(primes.take(100).mkString(", ")) }
    // => 2, 3, 5, 7, … , 541
    // => 18msec

    printExecutionTime { println(primes(9999)) }
    // => 104729
    // => 150msec
  }

  private def printExecutionTime(proc: => Unit) = {
    val start = System.currentTimeMillis
    proc
    println((System.currentTimeMillis - start) + "msec")
  }
}

ideone.com での実行結果
main 以降は、ワンソースで動作確認までするためのおまけです。気にしないでください)

先ほどの merge に加え、merge を再帰的に無限に適用する joinL と、merge と同様の発想で(狭義)単調増加列どうしの diff をとる minus の2関数を追加定義し、
「2以上の整数で、最初の数の(2倍以上の)倍数をすべて排除し、残った数列の次の数以降で同じことを繰り返す」
という、エラトステネスの篩の本来のアルゴリズムを再現したコードになっています。10000番目の素数も1秒未満で算出できます。

この実装を応用してもっと工夫すれば、もっともっと高速に無限列挙できる実装10もあるのですが、それはまた機会があれば。


  1. ただし Stream.from(n: Int): Stream[Int] つまり 32bit 整数にしか対応していないので 2147483648 以上の整数はこの方法では列挙できません。 

  2. 話の内容としては、「正格評価関数型プログラミングあるある」のようなので、知ってる人には当たり前な内容なのかも。 

  3. 実際にはそのほとんどが、祖先型である Seq トレイトで宣言または基本実装されており、それを継承・実装またはオーバーライドしています。 

  4. すべてパターンマッチに任せるのではなく、if〜else〜 を利用した方がきっと効率は良いのでしょうけれど、それはまた別のお話。 

  5. https://en.wikipedia.org/wiki/Regular_number 参照。 

  6. tail をラップして遅延評価戦略にうまく乗せる解決策も一応提案されています。 http://stackoverflow.com/questions/7492715/pattern-matching-and-infinite-streamsTailWrapper を導入した解決方法を参照。 

  7. Stream 構築目的の a #:: s という記述時は s は遅延評価。パターンマッチ時の x #:: xsxs は正格評価。これ重要。紛らわしいけれど。 

  8. 定義自体には問題がないからコンパイルエラー出ないんですよね。ホント気をつけないと。 

  9. 元々これを書きたくて mergeminus を自分で実装していて、「あれ、StackOverflowError で落ちる…」て気付いたのが、この記事を書くきかっけだったりします…。 

  10. 例えば「2,3,5 と『2でも3でも5でも割り切れない数』からなる『擬素数列』」を定義して利用することで篩の目を粗くするとか。 

14
15
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
14
15