この記事はオプティマインド Location Tech Advent Calendar 2022の21日目の記事となります。
こんにちは。株式会社オプティマインドの最適化チーム所属の伊豆原(イズハラ)と申します。
去年の年末に参加したアドベントカレンダーでは関数型言語に触れてみようと思いまして、Haskellのモナドと圏論のモナドとの関係性を勉強しました。
今年はまた年に1回の関数型言語を勉強する機会として、継続というものについて触れてみたいと思います。継続そのものは関数型言語特有のものではないのですが、Schemeという関数型言語の大きな特徴として使われているようです。WikibooksのScheme/継続のページを見ますと、なんでも"継続は計算の(デフォルトの)未来全体を表している"とかなんとか。
でもSchemeまったく知らないので、今回はC++で頑張ってみます!
継続
そもそも今まで継続というものに触れたことが無かったので、"継続とかは何か?"について書くのは私には難しいです。なのでこの記事では簡単のため、Wikipediaの以下のページで書かれている継続モナド(Continuation monad)をもって継続として、理解していきたいと思います。
リンク先によると継続(モナド)$Cont$は、適当な値域の型$R$を固定して以下のように定義されています。
Cont_R(T) := (T \to R) \to R
つまり関数$f:T\to R$を引数にし型$R$の値を取る関数を考えるものになります。型$T$の変数$x$は下式のように自然に$Cont(T)$の元$\bar{x}$と見做せるので、直感的には型$T$をそのものを考えるよりも少し広いものを扱える形になります(例えば常に定数を返すものなど)。
$$\bar{x}(f:T\to R):= f(x)$$
型を愚直に作りますと以下のような形になります。
template<typename A, typename B>
using Hom = std::function<B(A)>;
template<typename T, typename R>
using Cont = Hom<Hom<T, R>, R>;
例えば普通の整数を返す階乗関数は以下のように実装されますが、
int fact(int n) {
if (n == 0) {
return 1;
} else {
return fact2(n - 1) * n;
}
}
整数を引数にとる継続を返す階乗関数は以下のように実装されます。ややこしいですね。
Cont<int, int> factCont(int n) {
return [n](Hom<int, int> f) -> int{
if (n == 0) {
return f(1);
} else {
return fact(n - 1)([n, f](int m) {return f(m * n);});
}
};
}
わざわざ継続を返す関数を書かずとも、モナドの関数への作用や自然変換$\eta: id \to Cont$(とついでに$\mu: Cont^2 \to Cont$)を準備しておけば下記のように通常の階乗関数から継続版を作ることもできます。
(各ラムダ式は参照キャプチャにできる部分も多いと思いますが、とりあえず動かしたいのと実装の簡易のためすべてコピーキャプチャを使用しています)。
#include<iostream>
#include<functional>
template<typename A, typename B>
using Hom = std::function<B(A)>;
template<typename T, typename R>
using Cont = Hom<Hom<T, R>, R>;
// \eta : id -> Cont
template<typename T, typename R>
Cont<T, R> eta(T x) {
return [x](Hom<T, R> f) -> R {
return f(x);
};
}
// \mu : Cont^2 -> Cont
template<typename T, typename R>
Cont<T, R> mu(Cont<Cont<T, R>, R> f) {
return [f](Hom<T, R> g) -> Cont<T, R> {
return f(eta(g));
};
}
// f: x -> y => cont(f): Cont(x) -> Cont(y)
template<typename T, typename T2, typename R>
Hom<Cont<T, R>, Cont<T2, R>> cont(Hom<T, T2> f) {
return [f](Cont<T, R> g) -> Cont<T2, R> {
return [f, g](Hom<T2, R> h) -> R {
Hom<T, R> k = [f, h](T x) -> R {
return h(f(x));
};
return g(k);
};
};
}
// 通常の階乗関数: int -> int
int fact(int n) {
if (n == 0) {
return 1;
} else {
return fact(n - 1) * n;
}
}
// 継続を返す階乗関数: int ( -> Cont(int) ) -> Cont(int)
template<typename R>
Hom<int, Cont<int, R>> factCont = [](int n) {
return cont<int, int, R>(fact)(eta<int, R>(n));
};
int main() {
Hom<int, int> id = [](int x) {return x;};
Hom<int, void> printPlus1 = [](int x) {std::cout << x + 1 << std::endl; return;};
std::cout << factCont<int>(5)(id) << std::endl; // -> 120
factCont<void>(5)(printPlus1); // -> 121
return 0;
}
上のソースコードでは最後に値域$R$の異なる2つの関数id
とprintPlus1
を渡していますが、両方とも動いてますね。これらはつまり継続$(T\to R) \to R$に$T\to R$が与えられて型$R$の値を得た形になります。
以上で、この記事における継続の雰囲気は掴めましたかと思います。せっかくなので応用も見てみたいということで、よく継続とともに語られるcall//cc(call-with-current-continuation)にも触れたいと思います。
call/cc
call/ccを学ぶにあたって以下のページを参考にさせて頂きました。
- http://www.stdio.h.kyoto-u.ac.jp/jugyo1/scheme/SchemeNotes/continuation.html
- https://www.shido.info/lisp/callcc.html
さて、call/ccが何者かと言いますと、Wikipediaのモナドのページにcall/ccの型が書かれています。
\text{call/cc}: ((T\to (T'\to R) \to R) \to (T \to R) \to R) \to (T \to R) \to R \\
f \mapsto k \mapsto f(t\mapsto x \mapsto kt)k
なにか問題でも?と言われそうですが、分からんですね。どうも大域脱出とか再帰的処理に使えるようです。とりあえず$f$とか$k$をとかをバラしますと次のようになります。
\begin{align}
f&: (T \to Cont_R(T')) \to Cont_R(T) \\
k&: T \to R \\
t&: T \\
x&: T' \to R
\end{align}
つまりはこのよく分からない型の$f$を渡すと継続$Cont_R(T)$を返す関数がcall/ccのようです。定義に従って実装してみます。
template<typename T, typename T2, typename R>
Cont<T, R> call_cc(Hom<Hom<T, Cont<T2, R>>, Cont<T, R>> f) {
return [f](Hom<T, R> k) -> R {
Hom<T, Cont<T2, R>> cc = [k](T t) {
return [k, t](Hom<T2, R> x) -> R {
return k(t);
};
};
return f(cc)(k);
};
}
こうしてみると、call/ccの返り値である継続$Cont_R(T)$に渡される$k:T\to R$がcc
の内部に取り込まれているように見えますね。ここがcc
使用時に発生する大域脱出的な動きの要のようです。
また、型$T'\to R$の部分である$x$は何にも使われていないようです。つまり$Cont_R(T')$の部分は"どんな$x:T'\to R$を与えられても固定値$k(t)$を返す"継続ですかね。なので後の実装の簡単のため以下では$T=T'$としちゃっています(いいんだろうか)。
では大域脱出での使われ方を試してみます。整数のリストが与えられた時にその全要素の積を取り、ただし途中に0があったら大域脱出して0を返すような関数を作ります。これはSchemeでは次のように書けるようです。
(define-syntax print (syntax-rules () ((_ x) (begin (display x)(newline))))) ; 出力用
(define mul (lambda (x y) (begin (display "mul:") (print (* x y)) (* x y)))) ; ログ付き掛け算
(define (prod x)
(call/cc (lambda (cc) ; call/cc!
(letrec (
(_prod (lambda (y) (print y)
(let (
(result
(cond
((null? y) (begin (print "at end") 1))
((= (car y) 0) (begin (print "find 0") (cc 0)))
(#t (mul (car y) (_prod (cdr y))))
)))
result))))
(_prod x)))))
(print "# (1,2,3,4,5)")
(define v '(1 2 3 4 5))
(print (prod v))
(print "# (1,2,3,0,5)")
(define v '(1 2 3 0 5))
(print (prod v))
実際にGNU Guileを使って動かしてみますと、以下のような出力になります。要素に0が無い時は再帰的に積を取っていき、要素に0が見つかった場合は1回も掛け算をせずに0を返していることが分かります。素晴らしい。
# (1,2,3,4,5)
(1 2 3 4 5)
(2 3 4 5)
(3 4 5)
(4 5)
(5)
()
at end
mul:5
mul:20
mul:60
mul:120
mul:120
120
# (1,2,3,0,5)
(1 2 3 0 5)
(2 3 0 5)
(3 0 5)
(0 5)
find 0
0
ではこれをC++で再現してみます。再帰ごとに継続が積まれていき、0を検知したらcc(0)
を返します。cc(0)
はどのような$T' \to R$が与えられても$k(0)$を返す継続になり、この後にどのような継続が積まれても($Cont$の射での対応であるcont
の作り方から)無視されます。
#include<iostream>
#include<vector>
#include<functional>
// 出力用
void print() { std::cout << std::endl; }
template <class Head, class... Tail>
void print(Head&& head, Tail&&... tail) {
std::cout << head << ",";
print(std::forward<Tail>(tail)...);
}
// vectorのi番目以降をprint
template <typename A>
void vprint(std::vector<A> v, int i) {
for (int j = i; j < v.size(); j++) std::cout << v.at(j) << ",";
std::cout << std::endl;
}
template<typename A, typename B>
using Hom = std::function<B(A)>;
template<typename T, typename R>
using Cont = Hom<Hom<T, R>, R>;
// \eta : id -> Cont
template<typename T, typename R>
Cont<T, R> eta(T x) {
return [x](Hom<T, R> f) -> R {
return f(x);
};
}
// \mu : Cont^2 -> Cont は使わないので省略
// f: x -> y => cont(f): Cont(x) -> Cont(y)
template<typename T, typename T2, typename R>
Hom<Cont<T, R>, Cont<T2, R>> cont(Hom<T, T2> f) {
return [f](Cont<T, R> g) -> Cont<T2, R> {
return [f, g](Hom<T2, R> h) -> R {
Hom<T, R> k = [f, h](T x) -> R {
return h(f(x));
};
return g(k);
};
};
}
// 簡単のためT=T'
template<typename T, typename R>
Cont<T, R> call_cc(Hom<Hom<T, Cont<T, R>>, Cont<T, R>> f) {
return [f](Hom<T, R> k) -> R {
Hom<T, Cont<T, R>> cc = [k](T t) {
return [k, t](Hom<T, R> x) -> R {
return k(t);
};
};
return f(cc)(k);
};
}
// 掛け算の継続版
Hom<int, int> mul(int n) {
return [n](int x) {print("mul", n * x); return n * x;};
}
template<typename R>
Cont<int, R> _prod(Hom<int, Cont<int, R>> cc, std::vector<int>& v, int i) {
vprint(v, i);
if (i == v.size()) { print("at end"); return eta<int, R>(1); }; // 継続を返す
if (v[i] == 0) { print("find 0"); return cc(0); }; // ここの返り値は本来(T'->R)->R (T=T'はここの簡単のため)
return cont<int, int, R>(mul(v[i]))(_prod<R>(cc, v, i + 1)); // 整数の継続と整数の継続を掛け算の継続で掛けてる
}
template<typename R>
Cont<int, R> prod(std::vector<int>& v, int i) {
return call_cc<int, R>([i, &v](Hom<int, Cont<int, R>> cc) -> Cont<int, R> {
return _prod<R>(cc, v, i);
});
}
int main() {
Hom<int, void> print_int = [](int x) {std::cout << x << std::endl; return;};
print("# (1,2,3,4,5)");
std::vector<int> v = { 1,2,3,4,5 };
prod<void>(v, 0)(print_int);
print("# (1,2,3,0,5)");
std::vector<int> v1 = { 1,2,3,0,5 };
prod<void>(v1, 0)(print_int);
return 0;
}
出力は以下のようになります。ちゃんとScheme版のを再現できるようですね。丸写ししているので当然ではありますが……
# (1,2,3,4,5),
1,2,3,4,5,
2,3,4,5,
3,4,5,
4,5,
5,
at end,
mul,5,
mul,20,
mul,60,
mul,120,
mul,120,
120
# (1,2,3,0,5),
1,2,3,0,5,
2,3,0,5,
3,0,5,
0,5,
find 0,
0
ではもう一つ別の例で、計算の進行状況を保持している様子を、計算過程が積まれるcc
の再利用で見てみます。例には階乗関数を使用します。
template<typename R>
class Fact {
public:
Hom<int, Cont<int, R>> save;
Cont<int, R> calc(int n) {
print("fact", n);
if (n == 1) {
return call_cc<int, R>([this](Hom<int, Cont<int, R>> cc) -> Cont<int, R> {
this->save = cc;
return eta<int, R>(1);
});
} else {
return cont<int, int, R>(mul(n))(calc(n - 1));
};
}
};
int main() {
Hom<int, int> id = [](int x) {return x;};
Hom<int, int> plus123 = [](int x) {return x + 123;};
Fact<int> fact;
print(fact.calc(5)(id)); // 120
print(fact.save(10)(plus123)); // 1200
return 0;
}
流れとしてはまずcalc
で5の階乗が計算され、同時にcc
がsave
に保持されます。この時save
には(eta<int, R>(1)
以外の)階乗計算過程の処理が積まれて保持されている形になるので、再利用すれば再び階乗計算が行われます。またその計算の最後では大域脱出的な動きになるので、save
で返る継続には何を渡しても無視されます(上でいうplus123
)。
以上、call/ccの紹介でした。下記のサイトを見ますといろいろなことが出来そうな関数ですので、遊んでみると楽しいかと思われます。
まとめ
以上、継続とcall/ccについて触れてみました。本当はクリスマスも近いので、クリスマス限定->限定継続(delimited continuation) とかいってshift
やreset
などにも触れてみたかったのですが、これはまたの機会にしたいと思います。
読んで頂きありがとうございました!
さいごに
オプティマインドでは「多様性が進んだ世の中でも、全ての人に物が届く世界を持続可能にする」という物流業界の壮大な社会課題を解決すべく、一緒に働く仲間を大募集中です。少しでも興味が湧いた方はカジュアル面談も大歓迎ですので、気軽にお声がけください!