Java でメモ化を使いたい
メモ化というのはメソッド呼び出しの引数と返り値の組をメモリ上にキャッシュしておき、同じ引数で呼び出された際、キャッシュの方から値を引くことによって計算時間を節約するテクニックのことだ。このようなテクニックはどのような言語でも適用でき、もちろん Java においても同様だ。
しかし、何も準備せず書き始めると少しばかり記述が煩雑になってしまう。つまり、メソッド一つにつき一つづつキャッシュ用のマップを生成し、メソッド毎に引数をキーにキャッシュ用マップを検索し、あったらそちらを返し、なかったら計算する…というロジックを書くことになってしまいがちということだ。
そこで、少しばかりのコードを準備する。(以前説明した Tuple を使用している)
Memoize.java
public class Memoize {
protected static class MemoFunc<I, O> {
protected Map<I, O> cache = new HashMap<>();
protected Function<I, O> body = null;
protected O get(I i) {return cache.computeIfAbsent(i, body);}
protected void bind(Function<I, O> body_) {
if (body != null) throw new RuntimeException("binding exists already"); else body = body_;
}
}
public static class MemoFunc1<I1, O> extends MemoFunc<Tuple1<I1>, O> {
public O call(I1 i1) {return get(new Tuple1<>(i1));}
public MemoFunc1<I1, O> defun(_MemoFunc1<I1, O> body) {bind(i -> body.calculate(i.car)); return this;}
}
@FunctionalInterface public static interface _MemoFunc1<I1, O> {public O calculate(I1 i1);}
public static class MemoFunc2<I1, I2, O> extends MemoFunc<Tuple2<I1, I2>, O> {
public O call(I1 i1, I2 i2) {return get(new Tuple2<>(i1, i2));}
public MemoFunc2<I1, I2, O> defun(_MemoFunc2<I1, I2, O> body) {bind(i -> body.calculate(i.car, i.cdr.car)); return this;}
}
@FunctionalInterface public static interface _MemoFunc2<I1, I2, O> {public O calculate(I1 i1, I2 i2);}
public static class MemoFunc3<I1, I2, I3, O> extends MemoFunc<Tuple3<I1, I2, I3>, O> {
public O call(I1 i1, I2 i2, I3 i3) {return get(new Tuple3<>(i1, i2, i3));}
public MemoFunc3<I1, I2, I3, O> defun(_MemoFunc3<I1, I2, I3, O> body) {bind(i -> body.calculate(i.car, i.cdr.car, i.cdr.cdr.car)); return this;}
}
@FunctionalInterface public static interface _MemoFunc3<I1, I2, I3, O> {public O calculate(I1 i1, I2 i2, I3 i3);}
public static class MemoFunc4<I1, I2, I3, I4, O> extends MemoFunc<Tuple4<I1, I2, I3, I4>, O> {
public O call(I1 i1, I2 i2, I3 i3, I4 i4) {return get(new Tuple4<>(i1, i2, i3, i4));}
public MemoFunc4<I1, I2, I3, I4, O> defun(_MemoFunc4<I1, I2, I3, I4, O> body) {bind(i -> body.calculate(i.car, i.cdr.car, i.cdr.cdr.car, i.cdr.cdr.cdr.car)); return this;}
}
@FunctionalInterface public static interface _MemoFunc4<I1, I2, I3, I4, O> {public O calculate(I1 i1, I2 i2, I3 i3, I4 i4);}
public static class MemoFunc5<I1, I2, I3, I4, I5, O> extends MemoFunc<Tuple5<I1, I2, I3, I4, I5>, O> {
public O call(I1 i1, I2 i2, I3 i3, I4 i4, I5 i5) {return get(new Tuple5<>(i1, i2, i3, i4, i5));}
public MemoFunc5<I1, I2, I3, I4, I5, O> defun(_MemoFunc5<I1, I2, I3, I4, I5, O> body) {bind(i -> body.calculate(i.car, i.cdr.car, i.cdr.cdr.car, i.cdr.cdr.cdr.car, i.cdr.cdr.cdr.cdr.car)); return this;}
}
@FunctionalInterface public static interface _MemoFunc5<I1, I2, I3, I4, I5, O> {public O calculate(I1 i1, I2 i2, I3 i3, I4 i4, I5 i5);}
}
こんな風に準備しておくと、たとえばフィボナッチ数列の計算(これは遅延評価やメモ化が高い効果を発揮することで有名な計算だ)が以下のように簡潔に書けるようになる。
main.java
public static MemoFunc1<Long, Long> fibonacci = new MemoFunc1<Long, Long>() {{
defun(
n -> {
if (n == 0) return 0L;
if (n == 1) return 1L;
return fibonacci.call(n-1) + fibonacci.call(n-2);
}
);
}};
public static void main(String[] args) {
System.out.println(fibonacci.call(50L));
}
こんなことまでしてやることがフィボナッチかよ、とか言わないで欲しい。弊社は金融系のシステム開発を営む SIer だが、デリバティブとか仕組債の現在価値を計算するプログラムでは論文等に出現する式を「できるだけそのままの見た目で」記述できる(宣言的にプログラミングできる)ことは大きなメリットなのだ。ただ、メモ化をせず、遅延評価等も備えない言語でそれをやってしまうと同じ計算が無数に発生し、遅くて使いものにならないケースが多い。そのような用途でここに挙げたようなものと似たコードが現実に力を発揮している。