Scalaの文法に慣れて来た方を対象に、ラムダ式や高階関数を使って関数を取り回す方法を説明します。カリー化や部分適用も取り上げます。いわゆる関数型言語らしい機能です。
この記事は以下の記事のScala版です。故に、結構無理しているところもあります。
練習の解答例は別記事に掲載します。
ラムダ式
今まで取り上げて来た文法では、関数の引数を左辺で定義していました。
def inc(x: Int) = x + 1
println(inc(5))
6
引数を右辺で定義する文法があります。
val inc: Int => Int = x => x + 1
println(inc(5))
6
この右辺をラムダ式と呼びます。
戻り値の型を省略することも出来ます。これは型推論で戻り値を決めているからで、戻り値の型を指定しないといけない場合もあります。
val inc = (x: Int) => x + 1
本章では型推論できない場合は型を指定します。
練習
【問1】次に示す関数fact
をラムダ式で書き換えてください。
def fact(n: Int): Int = {
n match {
case 0 => 1
case n if n > 0 => n * fact(n - 1)
}
}
println(fact(5))
println(fact(0))
⇒ 解答例
型注釈
型注釈とラムダ式を並べると、型注釈の書式がラムダ式と共通していることが分かります。
// 型注釈とラムダ式を並べると、型注釈の書式がラムダ式と共通していることが分かります
// が、ScalaはHaskellと違ってラムダ式と型注釈が一緒になっています
// 引数と戻り値の書き方が、本文と一致します
val inc: Int => Int
= x => x + 1
println(inc(0))
1
練習
【問2】次に示す関数add
をラムダ式で書き換えてください。
def add(x:Int, y:Int): Int = x + y
println(add(2, 3))
println(add(9, 3))
⇒ 解答例
無名関数
ラムダ式は名前のない関数(無名関数)で、それを変数に束縛していると捉えることができます。
val a = 1
val b: Int => Int = x => x + 1
println(b(a))
2
ラムダ式を束縛しないで使うこともできます。
println(((x => x + 1): Int => Int)(1))
2
一度しか使わない関数にわざわざ名前を付けるのが面倒なとき、ラムダ式は便利です。
複数の引数
次の2つは同じ型(Int, Int) => Int
です。
def mul1 (x: Int, y: Int): Int = x * y
val mul2: (Int, Int) => Int = (x, y) => x * y
println(mul1(2, 3))
println(mul2(2, 3))
練習
【問3】次に示す関数add
を定義せずに、呼び出し側で無名関数にインライン展開してください。
def add(x: Int, y: Int): Int = x + y
println(add(2, 3))
⇒ 解答例
高階関数
引数として関数を受け取ったり、戻り値として関数を返したりする関数を高階関数と呼びます。
引数
引数として関数を受け取る高階関数の例です。
def f(g: (Int, Int) => Int) = g(2, 3)
def add(x: Int, y: Int): Int = x + y
val mul: (Int, Int) => Int = (x, y) => x * y
// メソッドでもラムダ式でもどちらでも渡せる
println(f(add))
println(f(mul))
5
6
上のmul
はラムダ式が変数に束縛され、その変数をf
に渡しています。変数を経由せずにラムダ式を直接渡すこともできます。
def f(g: (Int, Int) => Int) = g(2, 3)
println(f((x, y) => x + y))
println(f((x, y) => x * y))
5
6
戻り値
戻り値として関数を返す高階関数の例です。
def add(x: Int): Int => Int = y => x + y
val add2 = add(2)
println(add2(3))
println((add(2))(3))
println(add(2)(3))
5
5
5
上のadd
はラムダ式(y: Int) => x + y
で表される無名関数を返す高階関数です。高階関数から戻された関数に後続の引数を渡すことで連続して呼び出すことができます。
練習
【問4】次に示す関数f
とadd
を定義せずに、呼び出し側で無名関数にインライン展開してください。
def f(g: (Int, Int) => Int) = g(1, 2)
def add(x: Int, y: Int): Int = x + y
println(f(add))
⇒ 解答例
【問5】次に示す関数add
を定義せずに、呼び出し側で無名関数にインライン展開してください。
def add(x: Int): Int => Int = y => x + y
println(add(1)(2))
⇒ 解答例
カリー化
複数の引数を取る関数に対して、引数を後ろから1つずつ右辺に移動させてみます。
def add1(x: Int, y: Int) = x + y
def add2(x: Int): Int => Int = y => x + y
def add3: Int => Int => Int = x => y => x + y
println(add1(2,3))
println(add2(2)(3))
println(add3(2)(3))
5
5
5
add2
とadd3
は同じように振る舞うため、定義は等価だと見なせます。このように引数を1つずつ分割して関数をネストさせることをカリー化と呼びます。Haskellでは複数の引数を取る関数は自動的にカリー化されます。
部分適用
引数が足りない場合、後で付け足せば呼び出しを完成させることができます。このようなことが可能になるのも関数がカリー化されているためです。
def add: Int => Int => Int = x => y => x + y
val add2 = add(2)
println(add(2)(3))
println(add2(3))
5
5
5
上のadd2
のように一部の引数を固定化して新しい関数を作り出すことを部分適用と呼びます。
カリー化されていない関数の部分適用
カリー化されていない関数を部分適用するには、引数をワイルドカード_
で指定します。型注釈は省略できません。
def add(x: Int, y: Int) = x + y
val add2 = add(2, _: Int)
println(add(2,3))
println(add2(3))
println((add(2, _: Int))(3))
この構文はラムダ式の糖衣構文で、ラッパーで包むことにより部分適用を実現しています。
val add2 = add(2, _: Int)
val add3 = (x: Int) => add(2, x)
注意点
カリー化と部分適用は混同されることがあります。具体的には、部分適用を指してカリー化と呼ばれることがありますが、これは誤用です。
改めて定義を確認します。
- カリー化とは:関数を引数1つずつに分割してネストさせること
- 部分適用とは:一部の引数を固定化して新しい関数を作り出すこと
練習
【問6】次に示す関数combine
を、引数1つずつに分割してネストさせたラムダ式で書き換えてください。
def combine[A](a: A, b: A, c: A) = a +: b +: List(c)
val a = combine(1, _: Int, _: Int)
val b = a(2, _: Int)
val c = b(3)
println(c)
println(combine('a', 'b', 'c'))
⇒ 解答例
【問7】次のコードから関数double
を除去してください。ラムダ式は使わないでください。
def f(xs: List[Int], g: Int => Int) = for (x <- xs) yield g(x)
def double (x: Int) = 2 * x
println(f(List(1, 2, 3, 4, 5), double))
ヒント: _
⇒ 解答例
演算子
演算子のラッパーを作れば高階関数に渡すことができます。
def f(g: (Int, Int) => Int) = g(2, 3)
println(f(_ + _))
println(f(_ * _))
5
6
セクション
オペランド(被演算子)を片方だけワイルドカードにすることもできます。(問7で出題)
def f(g: Int => Int) = g(5)
println(f(2 - _))
println(f(2 * _))
-3
このように片方のオペランド(被演算子)を省略した不完全な式をセクションと呼びます。
ラムダ式とセクションを対比します。
def f(g: Int => Int) = g(5)
println((f(x => 2 + x), f(2 + _)))
println((f(x => x + 2), f(_ + 2)))
println((f(x => 2 - x), f(2 - _)))
println((f(x => x - 2), f(_ - 2)))
(7, 7)
(7, 7)
(-3, -3)
(3, 3)
練習
【問8】次のコードからラムダ式を排除してください。新しい関数を定義してはいけません。
def f1(g: Int => Int) = g(1)
def f2(g: (Int, Int) => Int) = g(2, 3)
println(f1(x => x - 3))
println(f1(x => 3 - x))
println(f2(x, y) => x + y))
⇒ 解答例
色々な関数
高階関数をいくつか紹介します。
map
リストの要素すべてに同じ処理を施した別のリストを作成します。
同じことができるリスト内包表記と対比します。
println(List(1, 2, 3, 4, 5).map(x => x * 2))
println(for (x <- List(1, 2, 3, 4, 5)) yield x * 2)
List(2, 4, 6, 8, 10)
List(2, 4, 6, 8, 10)
リスト内包表記との使い分けは特に基準はありませんが、高階関数の扱いに慣れればリスト内包表記が冗長に感じるかもしれません。
filter
リストから要素を取り出す際に条件を指定できます。
同じことができるリスト内包表記と対比します。
println(List(1,2,3,4,5,6,7,8,9).filter(_ < 5))
println(for (x <- List(1, 2, 3, 4, 5) if x < 5) yield x)
List(1, 2, 3, 4)
List(1, 2, 3, 4)
flip
Haskellには2引数関数で引数の順序を反転するflipという関数がありますが、Scalaにはflipに相当する関数はありませんので、関数を定義します。
第2引数への部分適用をセクションやラッパーと対比します。
def flip[A, B, C](f: A => B => C)(x: B)(y: A) = f(y)(x)
def append: String => String => String = a => a + _
def sub: Int => Int => Int = a => a - _
val flipped = flip(append)(_)
println(flipped("foo")("bar"))
println(flip(append)("foo")("bar"))
println(flip(sub)(5)(3))
println(flip(sub)(3)(5))
barfoo
barfoo
-2
2
第1引数がラムダ式で記述すると長くなる時、先に第2引数を記述するために使うと便利です。Yコンビネータで実例を出します。
foldLeft
リストの要素を左から1つずつ処理しながら集計します。
foldLeft
でsum
相当の処理をしてみました。
println((1 to 100).sum)
println((1 to 100).foldLeft(0)((z, n) => z + n))
5050
5050
foldLeft
は手続型言語のループを関数化したものだと見なせます。次のJavaScriptコードと比較してみてください。
var sum = 0;
for (var i = 1; i <= 100; ++i) {
sum += i;
}
console.log(sum);
5050
集計の初期値として0
を用意して、ループでi
を次々に足していきます。そこから処理(足し算)と初期値を取り出して関数化したのがfoldLeft
だと見なすわけです。
foldRight
リストの要素を右から1つずつ処理しながら集計します。
println((1 to 5).foldRight(0)((z, n) => z - n))
println((1 to 5).foldLeft(0)((z, n) => z - n))
3
-15
再帰でリストの全要素を処理する際に、再帰から返って来た値を使って関数の戻り値を計算すると、戻り値が確定するのは再帰の復路です。この手の再帰を関数化したものだと見なせます。次のコードと比較してみてください。
def test(x: List[Int]): Int = {
x match {
case m if m.isEmpty => 0
case x: List[Int] => x.head - test(x.tail)
}
}
println(test(List(1, 2, 3, 4, 5)))
3
再帰の折り返し値として0
を用意して、復路でx
から次々に引いていきます。そこから処理(引き算)と折り返し値(最初に作用させる値)を取り出して関数化したのがfoldRight
だと見なすわけです。
-
1 - (2 - (3 - (4 - (5 - 0))))
→1 - 2 + 3 - 4 + 5
→3
※ 復路で実際の計算が始まるため、計算の順はリストの右からとなります。これがRightの意味です。
練習
【問9】map
, filter
, flip
, foldl
, foldr
を再帰で再実装してください。関数名にはm
を付けてください。
具体的には以下のコードが動くようにしてください。
println(mmap((_*2), (1 to 5)))
println(mfilter((x => x < 5), (1 to 9)))
println(mflip(mmap)(1 to 5)(_*2))
println(mfoldLeft((_+_), 0, (1 to 100)))
println(mfoldLeft((_-_), 0, (1 to 5)))
println(mfoldRight((_-_), 0, (1 to 5)))
List(2, 4, 6, 8, 10)
List(1, 2, 3, 4)
List(2, 4, 6, 8, 10)
5050
-15
3
⇒ 解答例
【問10】foldLeft
でreverse
とmaximum
とminimum
を再実装してください。関数名にはm
を付けてください。
具体的には以下のコードが動くようにしてください。
println(mreverse((-5 to 5)))
println(maximum((-5 to 5)))
println(minimum((-5 to 5)))
List(5, 4, 3, 2, 1, 0, -1, -2, -3, -4, -5)
5
-5
⇒ 解答例
【問11】次に示す関数qsort
をfilter
で書き替えてください。引数にはList型のリストを使用してください。
def qsort(xs: Array[Int]): Array[Int] = {
def swap(i: Int, j: Int) {
val t = xs(i); xs(i) = xs(j); xs(j) = t
}
def sort1(l: Int, r: Int) {
val pivot = xs((l + r) / 2)
var i = l; var j = r
while (i <= j) {
while (xs(i) < pivot) i += 1
while (xs(j) > pivot) j -= 1
if (i <= j) {
swap(i, j)
i += 1
j -= 1
}
}
if (l < j) sort1(l, j)
if (j < r) sort1(i, r)
}
sort1(0, xs.length - 1)
xs
}
val array = qsort(Array(4, 6, 9, 8, 3, 5, 1, 7, 2))
array.foreach(println)
このクイックソートのプログラムはこちらを参考にしました。
⇒ 解答例
【問12】次に示すバブルソートの関数bswap
をfoldReft
で書き替えてください。
def bsort(list: List[Int]): List[Int] = {
def bswap(xs: List[Int]): List[Int] = {
xs match {
case xs if xs.length == 1 => List(xs.head)
case xs => {
lazy val ys = bswap(xs.tail)
if (xs.head > ys.head) ys.head :: xs.head :: ys.tail
else xs.head :: ys.head :: ys.tail
}
}
}
lazy val ys = bswap(list)
list match {
case list if list.isEmpty => List()
case list => ys.head :: bsort(ys.tail)
}
}
println(bsort(List(4,3,1,5,2)))
List(1, 2, 3, 4, 5)
⇒ 解答例
この問題は次の記事を参考にしました。
- @kazu_yamamoto: リストの畳み込みと展開 - あどけない話 2011.9.13
不動点コンビネータ
自己参照のできない無名のラムダ式で再帰を実現するテクニックとして、不動点コンビネータを利用する方法があります。あまり使う機会はないかもしれませんが、たまに見掛けるので知識として知っておいても損はないでしょう。
Yコンビネータ
Yコンビネータ(不動点コンビネータの一種)と呼ばれる補助関数を定義します。
def Y[A,B]( f:((A => B), A ) => B, x:A ):B = f( ( y:A ) => Y( f,y ),x )
Yコンビネータにラムダ式を渡すと、ラムダ式の第1引数にYコンビネータに包まれた自分自身が渡されます。これを使うことで再帰ができます。
フィボナッチ数をインラインで実装した例です。ラムダ式の第1引数を関数名に見立てています。
def Y[A,B]( f:((A => B), A ) => B, x:A ):B = f( ( y:A ) => Y( f,y ),x )
val y = Y( (f:Int => Int , n:Int) => n match {
case m if m < 2 => m
case _ => f( n - 1 ) + f( n - 2 )
}, 10)
println(y)
55
Yコンビネータは次の記事を参考にしました。
- @kazu_yamamoto: Haskell で Y コンビネータ - あどけない話 2010.5.19
- @kazu_yamamoto: Yコンビネータのまとめ - あどけない話 2010.6.1
- @yuroyoro: scalaでYコンビネータ
練習
【問13】Yコンビネータを使って関数sum
を再実装してください。
⇒ 解答例
関数合成
関数を連続して呼ぶのに対して、関数を合成しても同じことができます。
def f(x: Int): Int = x + 1
def g(x: Int): Int = x * 2
println(f(g(1)))
println((f _).andThen(g)(1))
println((f _).compose(g)(1))
3
4
3
関数を使うだけだとあまり違いは感じませんが、高階関数に渡すときに関数合成は便利です。
def f(x: Int): Int = x + 1
def g(x: Int): Int = x * 2
def h(f: (Int) => Int): Int = f(1)
println(f(g(1)))
println(f(h(g)))
println(h((x: Int) => f(g(x))))
println((h _).andThen(f)(g))
println((f _).compose(h)(g))
3
3
3
3
3
2引数
g
の引数が増えるといきなり難しくなります。
def f(x: Int): Int = x * 2
def g(x: Int, y: Int): Int = x + y
println(f(g(1, 2)))
// println((g _).compose(f)(1)(2))
6
Scalaの関数合成については、こちらを参考にしました。
ポイントフリースタイル
部分適用を利用すれば、別の関数に渡すだけの引数を省略できます。これをポイントフリースタイルと呼びます。
簡単な例を示します。
def f1(x: Int, y: Int): Int = x - y
def f2(x: Int): Int = f1(3, x)
def f3 = f1(3, _:Int)
def f4 = f1(_:Int, _:Int)
println(f1(3, 2))
println(f2(2))
println(f3(2))
println(f4(3, 2))
def g1(x: Int): Int = f4(2, x)
def g2 = f4(2, _:Int)
println(g1(5))
println(g2(5))
1
1
1
1
-3
-3
関数合成と組み合わせれば色々なパターンで引数を排除することができます。ただあまり突き詰めると、すぐには読めないコードになるような印象があります。
もっと恐ろしいものの片鱗を味わいたい方は、次の記事を参照してみてください。
- @melponn: ポイントフリースタイル入門 - melpon日記 2011.10.31