flatMapをマスターする

  • 185
    Like
  • 0
    Comment
More than 1 year has passed since last update.

flatMapが使えるようになって自分のScalaレベルが格段に上がった印象を持っています。
ただ、Scalaをやり始めた頃はflatMapの挙動がどうもすんなり頭に入って来ず苦しんだ経験を持つのですが、結構同じ悩みを聞くことがあったので誰かの役に立てば、と思いまとめてみました。
モナドとかそういうのはおいておいて、flatMapの振る舞いを理解してしっかり使えるようになることが目的です。
※Java8でもflatMapがStreamに追加されていたので本質的な振る舞いは同じかなと思っているのですが役に立たないかもしれません。

基本

flattenとmapが同時に行われるのがflatMapです。
とりあえずこれだけ抑えておけばflatMapはコワクナイ。

flatten機能の確認

まずはあえてmap機能を使わずにflattenの動きだけを確認。

Seq(Seq(1,2,3), Seq(4), Seq(5, 6)) flatMap { x => x }

結果はSeq(1,2,3,4,5,6)となります。
よく似たものとして、

Seq(Seq(1,2,3), Seq(), Seq(5, 6)) flatMap { x => x }

結果はSeq(1,2,3,5,6)となります。
空Seqはflattenにより打ち消されていることが分かると思います。

もちろん、map処理していないので、この時点では下記と同じです。

Seq(Seq(1,2,3), Seq(), Seq(5, 6)) flatten

flatten + map

次にmap機能も試します。

Seq(Seq(1,2,3), Seq(), Seq(5,6)) flatMap { x => 10 +: x }

結果はSeq(10,1,2,3,10,10,5,6)です。
各Seqの先頭要素に10を追加して新たなSeqに変換しています。

まとめると、flatMapは以下の様な動きとなっています。

Seq(Seq(1,2,3), Seq(), Seq(5,6))
 // map(各Seqの先頭要素に10を追加)
Seq(Seq(10,1,2,3), Seq(10), Seq(10,5,6))
 // flatten
Seq(10,1,2,3,10,10,5,6)

理解に苦しんだもの

これだけ見ると特段難しくないように見えますが、自分の経験からいくと主に以下のものがうまく整理出来てなかったがために身について無かったように思います。

その1

Seq(Seq(1,2,3), Seq(), Seq(5,6)) flatMap { x => x.size }

これは、Seq(3, 0, 2)っていう結果を期待して書いたのですが、コンパイルエラーになります。

 found   : Int
 required: scala.collection.GenTraversableOnce[?]
              Seq(Seq(1,2,3), Seq(), Seq(5,6)) flatMap { x => x.size }

flatMap内の関数の戻り値型はscala.collection.GenTraversableOnceなのにIntが返っているから、という典型的な型違いのエラーです。
これは結果が入れ子構造になってないのでflattenのしようがないわけで、こういう時はflatMapじゃなくてmapを使います。
flatMapとmapがこんがらがっていた頃はこのようなミスをたまに、いや正直に告白するとよくしてしまいました。。
Seqが入れ子構造になっているのを見てflatMapだ!と脊髄反射するとこうなります。見るところはそこじゃ無いんですよね。

その2

Seq(Seq("hello", "world"), Seq("good", "morning")) flatMap { x => x.mkString(" ").toUpperCase }

Seq("HELLO WORLD", "GOOD MORNING")になるのかと思いきや、
Seq(H, E, L, L, O, , W, O, R, L, D, G, O, O, D, , M, O, R, N, I, N, G)になります。

あと、よく似たものとして、

Seq(1, 2, 3) flatMap { x => x }

これはコンパイルエラーなのに(まぁ当然)

Seq("hello", "world") flatMap { x => x }

こっちはコンパイルエラーにはならない(あれっ)。

結果はSeq(h, e, l, l, o, w, o, r, l, d)です。

Stringは以下の関係が成り立つことに注意、です。

val x: scala.collection.GenTraversableOnce[Char] = "hello"

"hello"はSeq('h','e','l','l','o')っていうイメージで捉えておけば難しく無いと思います。
Scaladocをちゃんと見ると、Stringのメソッドとコレクション系クラスのメソッドは同じものが多いですね。

その3

Seq(Some(1), Some(2), None, Some(4)) flatMap { x => x } mkString(",")

結果は"1,2,4"です。

その2のStringまではギリギリ想像できると思うのですが、知らないと思いつきにくいのが上記のコード。
もうお分かりだと思いますが、Optionも以下は問題なく通ります。

val x: scala.collection.GenTraversableOnce[Int] = Some(1)

1や2をOptionという箱で包んでいて、flattenによりその箱が無くなるとイメージするとOptionも他のコレクション系クラスと同様にflatMapと組み合わせて使うことが容易になるかと思います。
NoneというのはSeq()、Nilなんかと同じイメージですね。

ちょっとした応用的な話

flatMap内の結果をSomeとNoneにくるんで返すことで後続の処理にて自動的にSomeのものだけ処理を続けることが出来るようになります。

Seq(1,2,3,4) flatMap { x => 
  if(x % 2 == 0) Some(x) else None 
} map { x =>
  x * 2
} foreach { 
  println 
}

結果はSeq(4,8)となり、flatMap以降の関数には奇数は渡ってきません。
上記の例程度ならfilter + mapの組み合わせ(もしくはcollect)でも出来ますが技術の引き出しを多く持っておくことで適材適所で使えるようになってきます。

なお上記は下記のように書くことも出来ます。

for {
  x <- Seq(1, 2, 3, 4) if x % 2 == 0
  y = x * 2
} println(y)

おまけ

先ほどのこのコードはmapで書きかえるべきだと書きましたが、flatMapのまま、期待するSeq(3,0,2)を得ようとするとどうしたらよいでしょうか?

Seq(Seq(1,2,3), Seq(), Seq(5,6)) flatMap { x => x.size }

GenTraversableOnceを返せばいいので、以下のどれでもOKですね。

// その1
Seq(Seq(1,2,3), Seq(), Seq(5,6)) flatMap { x => Seq(x.size) }
// その2
Seq(Seq(1,2,3), Seq(), Seq(5,6)) flatMap { x => List(x.size) }
// その3
Seq(Seq(1,2,3), Seq(), Seq(5,6)) flatMap { x => Set(x.size) }
// その4
Seq(Seq(1,2,3), Seq(), Seq(5,6)) flatMap { x => Some(x.size) }
// その5
Seq(Seq(1,2,3), Seq(), Seq(5,6)) flatMap { x => x.size.toString } map { String.valueOf(_).toInt }

・・・ただ、このケースではやはりmapを素直に使いましょう(笑)。

恒等関数

本記事で何度か書いたこのコードは、

Seq(Seq(1,2,3), Seq(4), Seq(5, 6)).flatMap{ x => x }

Predefで宣言されているidentityという恒等関数を使ってこう書くことも出来ます。

Seq(Seq(1,2,3), Seq(4), Seq(5, 6)).flatMap(identity)