Kotlin
algorithm
再帰
末尾再帰
末尾再帰最適化

【Kotlin】トランポリンで再帰を末尾再帰にする

この記事はデータ構造とアルゴリズム #2 Advent Calendar19日目の記事です。


2018/12/22追記

ArrowというKotlinでの関数型プログラミングを支援するライブラリにトランポリンの実装が有りましたので、Arrowを使った方がいいです。


TL;DR

Kotlinで作ったトランポリンから、トランポリンによる末尾再帰の仕組みについて書きます。


前書き

この記事は以下の記事の続きです。

トランポリンの実装は以下を用います。


Trampoline.kt

package trampoline

//トランポリン
sealed class Trampoline<T>{
//処理の継続
class More<T>(val calc: () -> Trampoline<T>): Trampoline<T>(){
override fun getValue(): T {
return calc().getValue()
}
}
//処理の終了
class Done<T>(private val value: T): Trampoline<T>() {
override fun getValue(): T {
return value
}
}
//値取り出し用のインターフェース
abstract fun getValue(): T
}

//トランポリンのランナー
fun <T> run(func: () -> Trampoline<T>): T = run(func()).let {
when (it) {
is Trampoline.Done -> it.getValue()
else -> throw Exception()
}
}
//再帰処理本体
private tailrec fun <T> run(trampoline: Trampoline<T>): Trampoline<T> = when (trampoline) {
is Trampoline.Done -> trampoline
is Trampoline.More -> run(trampoline.calc())
}



トランポリンを使う

まず、前書きに載せたトランポリンを用いて再帰的に剰余計算を行ってみます。

コメントに書いたとおり、mod_rec関数をそのまま実行1するとStackOverflowErrorを起こしますが、mod関数を介して実行すると正常に計算ができます。


トランポリンを使った再帰的な剰余計算

import trampoline.Trampoline

fun mod_rec(n: Int, m:Int): Trampoline<Int> {
if(n < m) return Trampoline.Done(n)
return Trampoline.More{ mod_rec(n-m, m) }
}
fun mod(n: Int, m: Int): Int{
return trampoline.run { mod_rec(n, m) }
}

fun main(args: Array<String>) {
println(mod(50000, 3)) //正常に計算できる
println(mod_rec(50000, 3).getValue()) //StackOverflowError
}



動作の仕組み

以下はトランポリンをデコンパイルしたものです。末尾再帰最適化によってループによる処理になっていることが分かります。

重要なのはtrampoline = (Trampoline)((More)trampoline).getCalc().invoke();です。

通常の再帰では処理の戻り先がスタックに積まれていくためにスタックオーバーフローが発生しますが、代入によって処理を更新することでスタックが溜まらないようになっています。


TrampolineKt.decompiled.java

package trampoline;

import kotlin.Metadata;
import kotlin.NoWhenBranchMatchedException;
import kotlin.jvm.functions.Function0;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;
import trampoline.Trampoline.Done;
import trampoline.Trampoline.More;

@Metadata(
mv = {1, 1, 13},
bv = {1, 0, 3},
k = 2,
d1 = {"\u0000\u0010\n\u0002\b\u0003\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0003\u001a%\u0010\u0000\u001a\u0002H\u0001\"\u0004\b\u0000\u0010\u00012\u0012\u0010\u0002\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u0002H\u00010\u00040\u0003¢\u0006\u0002\u0010\u0005\u001a#\u0010\u0000\u001a\b\u0012\u0004\u0012\u0002H\u00010\u0004\"\u0004\b\u0000\u0010\u00012\f\u0010\u0006\u001a\b\u0012\u0004\u0012\u0002H\u00010\u0004H\u0082\u0010¨\u0006\u0007"},
d2 = {"run", "T", "func", "Lkotlin/Function0;", "Ltrampoline/Trampoline;", "(Lkotlin/jvm/functions/Function0;)Ljava/lang/Object;", "trampoline", "TailrecTest"}
)
public final class TrampolineKt {
public static final Object run(@NotNull Function0 func) {
Intrinsics.checkParameterIsNotNull(func, "func");
Trampoline var1 = run((Trampoline)func.invoke());
if (var1 instanceof Done) {
return var1.getValue();
} else {
throw (Throwable)(new Exception());
}
}

private static final Trampoline run(Trampoline trampoline) {
while(!(trampoline instanceof Done)) {
if (!(trampoline instanceof More)) {
throw new NoWhenBranchMatchedException();
}

trampoline = (Trampoline)((More)trampoline).getCalc().invoke();
}

return trampoline;
}
}



何が嬉しいの?

以下のような相互再帰のコードは、通常の方法では末尾再帰にできませんが、トランポリンを使うことで末尾再帰最適化を行うことができます。


相互再帰で奇数・偶数判定

import trampoline.Trampoline

//奇数判定
fun odd(n: Int): Trampoline<Boolean> = when (n) {
0 -> Trampoline.Done(false)
else -> Trampoline.More { even(n - 1) }
}
//偶数判定
fun even(n: Int): Trampoline<Boolean> = when (n) {
0 -> Trampoline.Done(true)
else -> Trampoline.More { odd(n - 1) }
}

fun main(args: Array<String>) {
println(trampoline.run { even(10000) })
println(trampoline.run { odd(10000) })
}



終わりに

今回実装した内容は限定的なトランポリンで、やろうと思えば階乗やフィボナッチ数列(fibo(n) = fibo(n-1) + fibo(n-2)で計算する方)なんかもトランポリンで計算できるようですが、難しかった思いつくアルゴリズムはトランポリン無しで末尾再帰最適化ができたので触れないこととしました。

一応Scalaのライブラリにあるトランポリンはその辺りもちゃんとできるようです。

23日にMicroAd Advent Calenderで末尾再帰についてまた書く予定なので、もしよろしければそちらも読んでみて下さい。


参考にさせて頂いた記事


おまけ


今回使った奇数・偶数判定と剰余演算をまとめたもの

//再帰を使う計算の実装クラス

class RecCalc(private val n: Int){
//奇数判定
private fun odd(n: Int): Trampoline<Boolean> = when (n) {
0 -> Trampoline.Done(false)
else -> Trampoline.More { even(n - 1) }
}
//偶数判定
private fun even(n: Int): Trampoline<Boolean> = when (n) {
0 -> Trampoline.Done(true)
else -> Trampoline.More { odd(n - 1) }
}
//偶奇判定用プロパティ
val isEven: Boolean get() { return trampoline.run { even(n) } }
val isOdd: Boolean get() { return trampoline.run { odd(n) } }

//剰余演算
private fun mod(n: Int, m:Int): Trampoline<Int> {
if(n < m) return Trampoline.Done(n)
return Trampoline.More{ mod(n-m, m) }
}
fun mod(m: Int): Int{
return trampoline.run { mod(n, m) }
}
}






  1. Trampolinetrampolineが出てきますが、前者が型名で後者がパッケージ名です。