0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

Javaで不動点コンビネータを活用してメモ化とトレース機能を実現する

Posted at

はじめに

不動点コンビネータはWikipediaでは以下のように説明されています。

不動点コンビネータ(ふどうてんコンビネータ、英: fixed point combinator、不動点結合子、ふどうてんけつごうし)とは、与えられた関数の不動点(のひとつ)を求める高階関数である。不動点演算子(ふどうてんえんざんし、英: fixed-point operator)、パラドキシカル結合子(英: paradoxical combinator)などとも呼ばれる。ここで関数fの不動点とは、f(x) = xを満たすようなxのことをいう。

この記事ではJavaでこれをどのように実現するかを示し、それを活用して関数のメモ化と、関数のトレース機能を実現する方法を記述します。不動点コンビネータそのものには深入りせず、応用に焦点をあてています。

この記事ではJava14とJUnit5のみを使用しています。
この記事で使用したすべてのコードはここにあります。
TestFixedPointCombinator

Javaでの不動点コンビネータの実現

Javaでの実現方法はY combinator - Rosetta Codeに記載されています。Y combinatorは不動点コンビネータのひとつです。

interface RecursiveFunction<F> extends Function<RecursiveFunction<F>, F> {
}

static <A, B> Function<A, B> Y(Function<Function<A, B>, Function<A, B>> f) {
    RecursiveFunction<Function<A, B>> r = w -> f.apply(x -> w.apply(w).apply(x));
    return r.apply(r);
}

@Test
void 不動点コンビネータのテスト() {
    Function<Integer, Integer> fib = Y(
        f -> n -> (n <= 2)
            ? 1
            : (f.apply(n - 1) + f.apply(n - 2)));
    Function<Integer, Integer> fac = Y(
        f -> n -> (n <= 1)
            ? 1
            : (n * f.apply(n - 1)));
    System.out.println("fib(10) = " + fib.apply(10));
    System.out.println("fac(10) = " + fac.apply(10));
}
fib(10) = 55
fac(10) = 3628800

再帰的なインタフェースの定義からはじまって、何やら難しそうなコードが並んでいますが、このコードを理解する必要はありません。同じことができるより単純なコードを以下に示します。このコードは再帰呼び出しによって階乗を計算するものです。

static <T, R> Function<T, R> fixedPointCombinator(Function<Function<T, R>, Function<T, R>> f) {
    return new Function<T, R>() {
        @Override
        public R apply(T t) {
            return f.apply(this).apply(t);
        }
    };
}

static Function<Function<Integer, Integer>, Function<Integer, Integer>> factorial =
    self -> n ->
        n <= 0
            ? 1
            : n * self.apply(n - 1);

@Test
void 単純化した不動点コンビネータのテスト() {
    System.out.println("factorial(10) = " + fixedPointCombinator(factorial).apply(10));
}
factorial(10) = 3628800

再帰呼び出しの階乗ならもっと簡潔に書くことができますが、このコードのポイントはfactorialの定義の中にfactorialの直接的な呼び出しを含んでいない点です。そもそも、ここでのfactorialはメソッド定義ではなくて、単なる変数の宣言なのでfactorial = の後にfactorialが登場したらコンパイルエラーになります。
不動点コンビネータはこのように明示的な再帰を書かずに再帰を実現するためのものです。
不動点コンビネータ - Wikipedia」には以下のように記述されています。

不動点コンビネータにより、第一級関数をサポートしているプログラミング言語において、明示的に再帰を書かずに再帰を実現する為に用いる事ができる。なお、一般にそういった言語では普通に再帰が使えるので、プログラミングにおいてはパズル的なテクニック以上の意味は無い。一方、循環なく関数の意味を定義する(できる)、ということは、計算理論の上では重要である。

もちろんJavaでは再帰的な関数(メソッド)は簡単に定義できるので、前述のコードはWikipedia的には 「パズル的なテクニック以上の意味は無い」 ということになります。ただこの実装で面白いと思うのは、R apply(T t)の部分です。これは実際に階乗を計算する機能を呼び出して、その結果を返しているだけですが、階乗計算の呼び出しは常にここで行われているという点です。
以下のようにすれば階乗計算の呼び出しの前後に処理を追加することができるということです。

        @Override
        public R apply(T t) {
            // 前処理
            R result = f.apply(this).apply(t);
            // 後処理
            return result;
        }

メモ化の実現

例えば「メモ化」のような処理です。一度計算した結果をどこかに保存しておけば無駄な再計算を防ぐことかでき、高速化することができる場合があります。
メモ化する不動点コンビネータは以下のようになります。

static <T, R> Function<T, R> memoize(Function<Function<T, R>, Function<T, R>> f) {
    return new Function<T, R>() {
        final Map<T, R> cache = new HashMap<>();

        @Override
        public R apply(T t) {
            R v = cache.get(t);
            if (v == null)
                cache.put(t, v = f.apply(this).apply(t));
            return v;
            // 以下の実装でもいいはずですがConcurrentModificationExceptionがスローされます。
            // return cache.computeIfAbsent(t, k -> f.apply(this).apply(k));
        }

        @Override
        public String toString() {
            return cache.toString();
        }
    };
}

static Function<Function<Integer, Integer>, Function<Integer, Integer>> fibonacci =
    self -> n ->
        n == 0 ? 0 :
        n == 1 ? 1 :
        self.apply(n - 1) + self.apply(n - 2);

@Test
void メモ化のテスト() {
    Function<Integer, Integer> memoizedFibonacci = memoize(fibonacci);
    System.out.println("fibonacci(10) = " + memoizedFibonacci.apply(10));
    System.out.println(memoizedFibonacci);
}
fibonacci(10) = 55
{0=0, 1=1, 2=1, 3=2, 4=3, 5=5, 6=8, 7=13, 8=21, 9=34, 10=55}

メモ化したフィボナッチ関数はtoString()でキャッシュの中身を返すので、関数オブジェクトそのものを印刷すると、それまでに蓄積したキャッシュの内容を確認することができます。

複数引数の実現

関数の引数がふたつ以上ある場合はどうすればいいでしょうか。例題として
竹内関数のメモ化を行ってみます。(竹内関数の詳細はWikipediaを参照してください)
Javaで普通に書くと以下のようになります。

static int tarai(int x, int y, int z) {
    if (x <= y)
        return y;
    else
        return tarai(tarai(x - 1, y, z),
                     tarai(y - 1, z, x),
                     tarai(z - 1, x, y));
}

recordによる複数引数の実現

引数がみっつあるので、これをひとまとめにするクラスを定義します。ここではJava14から追加されたrecordを使います。

static record Args(int x, int y, int z) {}

static Function<Function<Args, Integer>, Function<Args, Integer>> tarai =
    self -> a ->
        a.x <= a.y ?
            a.y :
            self.apply(new Args(self.apply(new Args(a.x - 1, a.y, a.z)),
                                self.apply(new Args(a.y - 1, a.z, a.x)),
                                self.apply(new Args(a.z - 1, a.x, a.y))));

@Test
void recordによる複数引数のメモ化() {
    Function<Args, Integer> memoizedTarai = memoize(tarai);
    System.out.println("tarai(3, 2, 1) = " + memoizedTarai.apply(new Args(3, 2, 1)));
    System.out.println("キャッシュの中身: " + memoizedTarai);
}
tarai(3, 2, 1) = 3
キャッシュの中身: {Args[x=1, y=3, z=2]=3, Args[x=2, y=2, z=1]=2, Args[x=3, y=2, z=1]=3, Args[x=1, y=1, z=3]=1, Args[x=2, y=1, z=3]=3, Args[x=0, y=3, z=2]=3}

ArgsクラスはMapにキーとして格納されるので、equals()hashCode()を適切に実装している必要がありますが、recordを使えば問題ありません。
メモ化の効果を測定してみます。

static String 時間測定(Supplier<String> s) {
    long start = System.currentTimeMillis();
    return s.get() + " : 所要時間 " + (System.currentTimeMillis() - start) + "ms";
}

@Test
void 通常の関数とrecordによる複数引数のメモ化の性能比較() {
    System.out.println(時間測定(() -> "通常の竹内関数           tarai(15, 7, 1) = " + tarai(15, 7, 1)));
    System.out.println(時間測定(() -> "メモ化竹内関数(record)   tarai(15, 7, 1) = " + memoize(tarai).apply(new Args(15, 7, 1))));
}
通常の竹内関数           tarai(15, 7, 1) = 15 : 所要時間 717ms
メモ化竹内関数(record)   tarai(15, 7, 1) = 15 : 所要時間 22ms

30倍くらい速くなっています。

カリー化による複数引数の実現

カリー化というすべてを単一引数の関数として実現する方法もあります。具体的にはtaraiに引数をひとつ渡すと、2引数の関数が返るようにします。さらに引数をひとつ渡すと1引数の関数が返り、最後にもうひとつ引数をを渡すと整数の結果が返るという具合です。

  • tarai -> Function<Integer, Function<Integer, Function<Integer, Integer>>>
  • tarai.apply(3) -> Function<Integer, Function<<Integer, Integer>>
  • tarai.apply(3).apply(2) -> Function<Integer, Integer>
  • tarai.apply(3).apply(2).apply(1) -> Integer
@Test
void カリー化による複数引数のメモ化() {
    Function<Integer, Function<Integer, Function<Integer, Integer>>> tarai =
        memoize(self -> x ->
            memoize(selfy -> y ->
                memoize(selfz -> z -> x <= y ? y
                    : self.apply(self.apply(x - 1).apply(y).apply(z))
                          .apply(self.apply(y - 1).apply(z).apply(x))
                          .apply(self.apply(z - 1).apply(x).apply(y)))));
    System.out.println("tarai(3, 2, 1) = " + tarai.apply(3).apply(2).apply(1));
    System.out.println("キャッシュの中身: " + tarai);
    System.out.println(時間測定(() -> "メモ化竹内関数(カリー化) tarai(15, 7, 1) = " + tarai.apply(15).apply(7).apply(1)));
}
tarai(3, 2, 1) = 3
キャッシュの中身: {0={3={2=3}}, 1={1={3=1}, 3={2=3}}, 2={1={3=3}, 2={1=2}}, 3={2={1=3}}}
メモ化竹内関数(カリー化) tarai(15, 7, 1) = 15 : 所要時間 1ms

recordを使うよりもカリー化の方が少し速いようです。キャッシュはMap<Integer, Map<Integer, Map<Integer, Integer>>>という入れ子のマップになります。(先頭の{0={3={2=3}},tarai(0, 3, 2)=3であることを示しています)

トレース機能

再帰的な関数は初心者にはわかりにくいですが、呼び出しのトレースを見ると理解しやすいかもしれません。トレース機能もメモ化と同様に不動点コンビネータを使って実現することができます。
トレース出力するために、引数をふたつ追加します。

  • String name - トレースに出力する関数の名前です。
  • Consumer<String> output - 実際にトレースを出力するConsumerです。例えば標準出力なら、System.out::println、ログ出力ならlogger::infoなどを指定します。
static <T, R> Function<T, R> trace(String name, Consumer<String> output, Function<Function<T, R>, Function<T, R>> f) {
    return new Function<T, R>() {
        int nest = 0;

        @Override
        public R apply(T t) {
            String indent = "  ".repeat(nest);
            output.accept(indent + name + "(" + t + ")");
            ++nest;
            R r = f.apply(this).apply(t);
            --nest;
            output.accept(indent + r);
            return r;
        }
    };
}

@Test
void トレースのテスト() {
    System.out.println("fibonacci(6) = " + trace("fibonacci", System.out::println, fibonacci).apply(6));
}
fibonacci(6)
  fibonacci(5)
    fibonacci(4)
      fibonacci(3)
        fibonacci(2)
          fibonacci(1)
          1
          fibonacci(0)
          0
        1
        fibonacci(1)
        1
      2
      fibonacci(2)
        fibonacci(1)
        1
        fibonacci(0)
        0
      1
    3
    fibonacci(3)
      fibonacci(2)
        fibonacci(1)
        1
        fibonacci(0)
        0
      1
      fibonacci(1)
      1
    2
  5
  fibonacci(4)
    fibonacci(3)
      fibonacci(2)
        fibonacci(1)
        1
        fibonacci(0)
        0
      1
      fibonacci(1)
      1
    2
    fibonacci(2)
      fibonacci(1)
      1
      fibonacci(0)
      0
    1
  3
8
fibonacci(6) = 8

これを見ると同じ引数で何度も呼び出されていることがわかります。そこでメモ化とトレース機能を同時に実装した不動点コンビネータも作って両者を比較してみます。キャッシュから得られた結果の後には(cache)の文字を出力するようにします。

static <T, R> Function<T, R> memoizeTrace(String name, Consumer<String> output, Function<Function<T, R>, Function<T, R>> f) {
    return new Function<T, R>() {
        Map<T, R> cache = new HashMap<>();
        int nest = 0;

        @Override
        public R apply(T t) {
            String indent = "  ".repeat(nest);
            output.accept(indent + name + "(" + t + ")");
            ++nest;
            R result = cache.get(t);
            String from = "";
            if (result == null)
                cache.put(t, result = f.apply(this).apply(t));
            else
                from = " (cache)";
            --nest;
            output.accept(indent + result + from);
            return result;
        }

        @Override
        public String toString() {
            return cache.toString();
        }
    };
}

@Test
void メモ化トレース() {
    System.out.println("トレース       fibonacci(6) = " + trace("fibonacci", System.out::println, fibonacci).apply(6));
    System.out.println("メモ化トレース fibonacci(6) = " + memoizeTrace("fibonacci", System.out::println, fibonacci).apply(6));
}
}

以下はそれぞれのトレース出力の差分を取ったものです。

fibonacci-memoize-trace.png

おわりに

不動点コンビネータを使って関数の定義を変えずにメモ化やトレース機能を追加することができました。ここで示した以外にも、特定の入力データを変更して関数適用したり、特定の出力データを変更して返すようなパッチ的な使い方もできますし、入出力データの統計を取ったりすることもできると思います。これらは本来の不動点コンビネータの主旨とは違っていますが、現実のプログラミングで役に立つのではないでしょうか。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?