Scala ラムダ超入門

  • 43
    いいね
  • 0
    コメント
この記事は最終更新日から1年以上が経過しています。

Scalaの文法に慣れて来た方を対象に、ラムダ式や高階関数を使って関数を取り回す方法を説明します。カリー化や部分適用も取り上げます。いわゆる関数型言語らしい機能です。

この記事は以下の記事のScala版です。故に、結構無理しているところもあります。

練習の解答例は別記事に掲載します。

ラムダ式

今まで取り上げて来た文法では、関数の引数を左辺で定義していました。

inc.scala
def inc(x: Int) = x + 1

println(inc(5))
実行結果
6

引数を右辺で定義する文法があります。

inc_lambda.scala
val inc: Int => Int = x => x + 1

println(inc(5))
実行結果
6

この右辺をラムダ式と呼びます。

戻り値の型を省略することも出来ます。これは型推論で戻り値を決めているからで、戻り値の型を指定しないといけない場合もあります。

inc_lambda.scala
val inc = (x: Int) => x + 1

本章では型推論できない場合は型を指定します。

練習

【問1】次に示す関数factをラムダ式で書き換えてください。

factorial.scala
def fact(n: Int): Int = {
  n match {
    case 0 => 1
    case n if n > 0 => n * fact(n - 1)
  }
}
println(fact(5))
println(fact(0))

解答例

型注釈

型注釈とラムダ式を並べると、型注釈の書式がラムダ式と共通していることが分かります。

inc_annotation.scala
// 型注釈とラムダ式を並べると、型注釈の書式がラムダ式と共通していることが分かります
// が、ScalaはHaskellと違ってラムダ式と型注釈が一緒になっています
// 引数と戻り値の書き方が、本文と一致します
val inc: Int => Int
       = x   => x + 1

println(inc(0))
実行結果
1

練習

【問2】次に示す関数addをラムダ式で書き換えてください。

add.scala
def add(x:Int, y:Int): Int = x + y

println(add(2, 3))
println(add(9, 3))

解答例

無名関数

ラムダ式は名前のない関数(無名関数)で、それを変数に束縛していると捉えることができます。

lambda_function_with_variable.scala
val a = 1
val b: Int => Int = x => x + 1

println(b(a))
実行結果
2

ラムダ式を束縛しないで使うこともできます。

lambda_function_without_variable.scala
println(((x => x + 1): Int => Int)(1))
実行結果
2

一度しか使わない関数にわざわざ名前を付けるのが面倒なとき、ラムダ式は便利です。

複数の引数

次の2つは同じ型(Int, Int) => Intです。

mul.scala
def mul1 (x: Int, y: Int): Int = x * y
val mul2: (Int, Int) => Int = (x, y) => x * y

println(mul1(2, 3))
println(mul2(2, 3))

練習

【問3】次に示す関数addを定義せずに、呼び出し側で無名関数にインライン展開してください。

lambda_function_add.scala
def add(x: Int, y: Int): Int = x + y

println(add(2, 3))

解答例

高階関数

引数として関数を受け取ったり、戻り値として関数を返したりする関数を高階関数と呼びます。

引数

引数として関数を受け取る高階関数の例です。

higher_order_with_variable.scala
def f(g: (Int, Int) => Int) = g(2, 3)

def add(x: Int, y: Int): Int = x + y
val mul: (Int, Int) => Int = (x, y) => x * y

// メソッドでもラムダ式でもどちらでも渡せる
println(f(add))
println(f(mul))
実行結果
5
6

上のmulはラムダ式が変数に束縛され、その変数をfに渡しています。変数を経由せずにラムダ式を直接渡すこともできます。

higher_order_without_variable.scala
def f(g: (Int, Int) => Int) = g(2, 3)

println(f((x, y) => x + y))
println(f((x, y) => x * y))
実行結果
5
6

戻り値

戻り値として関数を返す高階関数の例です。

higher_order_return_value.scala
def add(x: Int): Int => Int = y => x + y
val add2 = add(2)

println(add2(3))
println((add(2))(3))
println(add(2)(3))
実行結果
5
5
5

上のaddはラムダ式(y: Int) => x + yで表される無名関数を返す高階関数です。高階関数から戻された関数に後続の引数を渡すことで連続して呼び出すことができます。

練習

【問4】次に示す関数faddを定義せずに、呼び出し側で無名関数にインライン展開してください。

higher_order_parameter_with_function.scala
def f(g: (Int, Int) => Int) = g(1, 2)
def add(x: Int, y: Int): Int = x + y

println(f(add))

解答例

【問5】次に示す関数addを定義せずに、呼び出し側で無名関数にインライン展開してください。

add_return_value.scala
def add(x: Int): Int => Int = y => x + y

println(add(1)(2))

解答例

カリー化

複数の引数を取る関数に対して、引数を後ろから1つずつ右辺に移動させてみます。

curry.scala
def add1(x: Int, y: Int) = x + y
def add2(x: Int): Int => Int = y => x + y
def add3: Int => Int => Int = x => y => x + y

println(add1(2,3))
println(add2(2)(3))
println(add3(2)(3))
実行結果
5
5
5

add2add3は同じように振る舞うため、定義は等価だと見なせます。このように引数を1つずつ分割して関数をネストさせることをカリー化と呼びます。Haskellでは複数の引数を取る関数は自動的にカリー化されます。

部分適用

引数が足りない場合、後で付け足せば呼び出しを完成させることができます。このようなことが可能になるのも関数がカリー化されているためです。

partial_application.scala
def add: Int => Int => Int = x => y => x + y
val add2 = add(2)

println(add(2)(3))
println(add2(3))
実行結果
5
5
5

上のadd2のように一部の引数を固定化して新しい関数を作り出すことを部分適用と呼びます。

カリー化されていない関数の部分適用

カリー化されていない関数を部分適用するには、引数をワイルドカード_で指定します。型注釈は省略できません。

def add(x: Int, y: Int) = x + y
val add2 = add(2, _: Int)

println(add(2,3))
println(add2(3))
println((add(2, _: Int))(3))

この構文はラムダ式の糖衣構文で、ラッパーで包むことにより部分適用を実現しています。

比較
val add2 = add(2, _: Int)
val add3 = (x: Int) => add(2, x)

注意点

カリー化と部分適用は混同されることがあります。具体的には、部分適用を指してカリー化と呼ばれることがありますが、これは誤用です。

改めて定義を確認します。

  • カリー化とは:関数を引数1つずつに分割してネストさせること
  • 部分適用とは:一部の引数を固定化して新しい関数を作り出すこと

練習

【問6】次に示す関数combineを、引数1つずつに分割してネストさせたラムダ式で書き換えてください。

combine.scala
def combine[A](a: A, b: A, c: A) = a +: b +: List(c)

val a = combine(1, _: Int, _: Int)
val b = a(2, _: Int)
val c = b(3)
println(c)
println(combine('a', 'b', 'c'))

解答例

【問7】次のコードから関数doubleを除去してください。ラムダ式は使わないでください。

double.scala
def f(xs: List[Int], g: Int => Int) = for (x <- xs) yield g(x)
def double (x: Int) = 2 * x

println(f(List(1, 2, 3, 4, 5), double))

ヒント: _

解答例

演算子

演算子のラッパーを作れば高階関数に渡すことができます。

operator.scala
def f(g: (Int, Int) => Int) = g(2, 3)

println(f(_ + _))
println(f(_ * _))
実行結果
5
6

セクション

オペランド(被演算子)を片方だけワイルドカードにすることもできます。(問7で出題)

operator_partial_application.scala
def f(g: Int => Int) = g(5)

println(f(2 - _))
println(f(2 * _))
実行結果
-3

このように片方のオペランド(被演算子)を省略した不完全な式をセクションと呼びます。

ラムダ式とセクションを対比します。

section.scala
def f(g: Int => Int) = g(5)

println((f(x => 2 + x), f(2 + _)))
println((f(x => x + 2), f(_ + 2)))
println((f(x => 2 - x), f(2 - _)))
println((f(x => x - 2), f(_ - 2)))
実行結果
(7, 7)
(7, 7)
(-3, -3)
(3, 3)

練習

【問8】次のコードからラムダ式を排除してください。新しい関数を定義してはいけません。

operator_with_lambda.scala
def f1(g: Int => Int) = g(1)
def f2(g: (Int, Int) => Int) = g(2, 3)

println(f1(x => x - 3))
println(f1(x => 3 - x))
println(f2(x, y) => x + y))

解答例

色々な関数

高階関数をいくつか紹介します。

map

リストの要素すべてに同じ処理を施した別のリストを作成します。

同じことができるリスト内包表記と対比します。

map.scala
println(List(1, 2, 3, 4, 5).map(x => x * 2))
println(for (x <- List(1, 2, 3, 4, 5)) yield x * 2)
実行結果
List(2, 4, 6, 8, 10)
List(2, 4, 6, 8, 10)

リスト内包表記との使い分けは特に基準はありませんが、高階関数の扱いに慣れればリスト内包表記が冗長に感じるかもしれません。

filter

リストから要素を取り出す際に条件を指定できます。

同じことができるリスト内包表記と対比します。

filter.scala
println(List(1,2,3,4,5,6,7,8,9).filter(_ < 5))
println(for (x <- List(1, 2, 3, 4, 5) if x < 5) yield x)
実行結果
List(1, 2, 3, 4)
List(1, 2, 3, 4)

flip

Haskellには2引数関数で引数の順序を反転するflipという関数がありますが、Scalaにはflipに相当する関数はありませんので、関数を定義します。

第2引数への部分適用をセクションやラッパーと対比します。

flip.scala
def flip[A, B, C](f: A => B => C)(x: B)(y: A) = f(y)(x)
def append: String => String => String = a => a + _
def sub: Int => Int => Int = a => a - _

val flipped = flip(append)(_)
println(flipped("foo")("bar"))
println(flip(append)("foo")("bar"))

println(flip(sub)(5)(3))
println(flip(sub)(3)(5))
実行結果
barfoo
barfoo
-2
2

第1引数がラムダ式で記述すると長くなる時、先に第2引数を記述するために使うと便利です。Yコンビネータで実例を出します。

foldLeft

リストの要素を左から1つずつ処理しながら集計します。

foldLeftsum相当の処理をしてみました。

foldleft.scala
println((1 to 100).sum)
println((1 to 100).foldLeft(0)((z, n) => z + n))
実行結果
5050
5050

foldLeftは手続型言語のループを関数化したものだと見なせます。次のJavaScriptコードと比較してみてください。

JavaScript
var sum = 0;
for (var i = 1; i <= 100; ++i) {
    sum += i;
}
console.log(sum);
実行結果
5050

集計の初期値として0を用意して、ループでiを次々に足していきます。そこから処理(足し算)と初期値を取り出して関数化したのがfoldLeftだと見なすわけです。

foldRight

リストの要素を右から1つずつ処理しながら集計します。

println((1 to 5).foldRight(0)((z, n) => z - n))
println((1 to 5).foldLeft(0)((z, n) => z - n))
実行結果
3
-15

再帰でリストの全要素を処理する際に、再帰から返って来た値を使って関数の戻り値を計算すると、戻り値が確定するのは再帰の復路です。この手の再帰を関数化したものだと見なせます。次のコードと比較してみてください。

foldright_recursion.scala
def test(x: List[Int]): Int = {
  x match {
    case m if m.isEmpty => 0
    case x: List[Int] => x.head - test(x.tail)
  }
}

println(test(List(1, 2, 3, 4, 5)))
実行結果
3

再帰の折り返し値として0を用意して、復路でxから次々に引いていきます。そこから処理(引き算)と折り返し値(最初に作用させる値)を取り出して関数化したのがfoldRightだと見なすわけです。

  • 1 - (2 - (3 - (4 - (5 - 0))))1 - 2 + 3 - 4 + 53

※ 復路で実際の計算が始まるため、計算の順はリストの右からとなります。これがRightの意味です。

練習

【問9】map, filter, flip, foldl, foldrを再帰で再実装してください。関数名にはmを付けてください。

具体的には以下のコードが動くようにしてください。

recursion.scala
println(mmap((_*2), (1 to 5)))
println(mfilter((x => x < 5), (1 to 9)))
println(mflip(mmap)(1 to 5)(_*2))
println(mfoldLeft((_+_), 0, (1 to 100)))
println(mfoldLeft((_-_), 0, (1 to 5)))
println(mfoldRight((_-_), 0, (1 to 5)))
実行結果
List(2, 4, 6, 8, 10)
List(1, 2, 3, 4)
List(2, 4, 6, 8, 10)
5050
-15
3

解答例

【問10】foldLeftreversemaximumminimumを再実装してください。関数名にはmを付けてください。

具体的には以下のコードが動くようにしてください。

foldleft_reverse_max_min.scala
println(mreverse((-5 to 5)))
println(maximum((-5 to 5)))
println(minimum((-5 to 5)))
実行結果
List(5, 4, 3, 2, 1, 0, -1, -2, -3, -4, -5)
5
-5

解答例

【問11】次に示す関数qsortfilterで書き替えてください。引数にはList型のリストを使用してください。

qsort.scala
def qsort(xs: Array[Int]): Array[Int] = {
  def swap(i: Int, j: Int) {
    val t = xs(i); xs(i) = xs(j); xs(j) = t
  }
  def sort1(l: Int, r: Int) {
    val pivot = xs((l + r) / 2)
    var i = l; var j = r
    while (i <= j) {
      while (xs(i) < pivot) i += 1
      while (xs(j) > pivot) j -= 1
      if (i <= j) {
        swap(i, j)
        i += 1
        j -= 1
      }
    }
    if (l < j) sort1(l, j)
    if (j < r) sort1(i, r)
  }
  sort1(0, xs.length - 1)
  xs
}

val array = qsort(Array(4, 6, 9, 8, 3, 5, 1, 7, 2))
array.foreach(println)

このクイックソートのプログラムはこちらを参考にしました。

最初の例

解答例

【問12】次に示すバブルソートの関数bswapfoldReftで書き替えてください。

bsort.scala
def bsort(list: List[Int]): List[Int] = {
  def bswap(xs: List[Int]): List[Int] = {
    xs match {
      case xs if xs.length == 1 => List(xs.head)
      case xs => {
        lazy val ys = bswap(xs.tail)
        if (xs.head > ys.head) ys.head :: xs.head :: ys.tail
        else xs.head :: ys.head :: ys.tail
      }
    }
  }

  lazy val ys = bswap(list)
  list match {
    case list if list.isEmpty => List()
    case list => ys.head :: bsort(ys.tail)
  }
}

println(bsort(List(4,3,1,5,2)))
実行結果
List(1, 2, 3, 4, 5)

解答例

この問題は次の記事を参考にしました。

不動点コンビネータ

自己参照のできない無名のラムダ式で再帰を実現するテクニックとして、不動点コンビネータを利用する方法があります。あまり使う機会はないかもしれませんが、たまに見掛けるので知識として知っておいても損はないでしょう。

Yコンビネータ

Yコンビネータ(不動点コンビネータの一種)と呼ばれる補助関数を定義します。

y.scala
def Y[A,B]( f:((A => B), A ) => B, x:A ):B = f( ( y:A ) => Y( f,y ),x )

Yコンビネータにラムダ式を渡すと、ラムダ式の第1引数にYコンビネータに包まれた自分自身が渡されます。これを使うことで再帰ができます。

フィボナッチ数をインラインで実装した例です。ラムダ式の第1引数を関数名に見立てています。

fibonacci_y.scala
def Y[A,B]( f:((A => B), A ) => B, x:A ):B = f( ( y:A ) => Y( f,y ),x )

val y = Y( (f:Int => Int , n:Int) => n match {
  case m if m < 2 => m
  case _ => f( n - 1 ) + f( n - 2 )
}, 10)

println(y)
実行結果
55

Yコンビネータは次の記事を参考にしました。

練習

【問13】Yコンビネータを使って関数sumを再実装してください。

解答例

関数合成

関数を連続して呼ぶのに対して、関数を合成しても同じことができます。

comp_func.scala
def f(x: Int): Int = x + 1
def g(x: Int): Int = x * 2

println(f(g(1)))
println((f _).andThen(g)(1))
println((f _).compose(g)(1))
実行結果
3
4
3

関数を使うだけだとあまり違いは感じませんが、高階関数に渡すときに関数合成は便利です。

higher_comp_func.scala
def f(x: Int): Int = x + 1
def g(x: Int): Int = x * 2
def h(f: (Int) => Int): Int = f(1)

println(f(g(1)))
println(f(h(g)))
println(h((x: Int) => f(g(x))))
println((h _).andThen(f)(g))
println((f _).compose(h)(g))
実行結果
3
3
3
3
3

2引数

gの引数が増えるといきなり難しくなります。

comp_func_two_args.scala
def f(x: Int): Int = x * 2
def g(x: Int, y: Int): Int = x + y

println(f(g(1, 2)))
// println((g _).compose(f)(1)(2))
実行結果
6

Scalaの関数合成については、こちらを参考にしました。
* 関数合成

ポイントフリースタイル

部分適用を利用すれば、別の関数に渡すだけの引数を省略できます。これをポイントフリースタイルと呼びます。

簡単な例を示します。

pointfree_sub.scala
def f1(x: Int, y: Int): Int = x - y
def f2(x: Int): Int = f1(3, x)
def f3 = f1(3, _:Int)
def f4 = f1(_:Int, _:Int)

println(f1(3, 2))
println(f2(2))
println(f3(2))
println(f4(3, 2))

def g1(x: Int): Int = f4(2, x)
def g2 = f4(2, _:Int)

println(g1(5))
println(g2(5))
実行結果
1
1
1
1
-3
-3

関数合成と組み合わせれば色々なパターンで引数を排除することができます。ただあまり突き詰めると、すぐには読めないコードになるような印象があります。

もっと恐ろしいものの片鱗を味わいたい方は、次の記事を参照してみてください。