LoginSignup
11
9

More than 5 years have passed since last update.

C++でKnuth-Bendixの完備化アルゴリズムを頑張って実装した話

Last updated at Posted at 2015-07-11

はじめに

動機

最強の代数処理言語を作りたい……。LALRパーサジェネレーターを作った、高速多倍長整数型や可変長精度の浮動少数点数型も作った、数式処理のために記号積分やグレブナー基底も勉強した。あと足りないものはなんだ?そうだ、Knuth-Bendixの完備化アルゴリズムでTRSの書き換え規則を自動完備化するプロセスが足りない。作らなければ……。

Term Rewriting Sysmtem について

Wikipedia - 項書き換え
Term Rewriting System (TRS) は式やワードの列を規則に従って書き換えていき、これ以上書き換えができない最も『小さな』状態(標準形)になるまで繰り返します。これを簡約といいます。

合流性 (Confluence)

Wikipedia - 合流性
一般的に簡約対象の項の列が、標準形やある特定の同じ形に至るまでには何通りかの書き換え規則適用の組み合わせがあります。例えば次の規則 $R_{1}$ には項 $a$ が $d$ に至るまで $a \to b \to d$ と $a \to c \to d$ の二通りの書き換えの手順があり、最終的にどちらも同じ結果になっています。これを合流性といいます。

R_{1} = \left\{\ a \to b,\ a \to c,\ b \to d,\ c \to d\ \right\}

停止性 (Termination)

書き換え規則によっては結果が一意に定まらず、収束しないものがあります。次の書き換え規則 $R_{2}$ を見てみましょう。

R_{2} = \left\{\ a \to b,\ b \to c,\ c \to a\ \right\}

$a$ が与えられ書き換え規則を適用していくと、 $a \to b$ 、 $b \to c$ 、 $c \to a$ と元の項に戻ってしまい、永久に書き換えが停止しないことが分かります。この性質がなく、書き換え対象について必ず手続きが停止することを停止性といいます。

危険対 (Critical pairs)

あるひとつの項に対して2つの書き換え規則が適用でき、それぞれの適用に合流性がないとき、その書き換え結果2つを危険対といいます。危険対がある書き換え規則は正しい簡約結果を導出できません。

Knuth-Bendix の完備化アルゴリズムについて

Wikipedia - クヌース・ベンディックス完備化アルゴリズム
TRS において与えられた書き換え規則に基づいて危険対を検出し、収束させるために規則を更に追加する手続きです。このアルゴリズムは必ず停止する保証がないため、厳密にはセミ・アルゴリズム (Semi-Algorithm) と呼称します。ただし、停止した場合に得られた結果の完備化された書き換え規則は、書き換え対象に対して停止性が保証されます。

既存の実装

TRS の教科書として、 Franz Baarder 、 Tobias Nipkow 著の Term Rewriting and All That が有名で ( Amazon でも買えます ) 、こちらに ML による実装が紹介されておりサポートページからもコードを閲覧できます
また、以前に私も上記の実装を OCaml の処理系で動作するコードへ移植しており、 今回はそれを更に C++ へ移植したコードをアルゴリズム全体と共に紹介したいと思います。

実装と詳細

C++ 実装は gcc 4.9.1 を用いてコンパイルと動作確認を行いました。

クラス

C++ でまとまった実装を行うためにまずクラスを記述します。項の型と、項の比較を行う Functor を template 引数に取ります。

template<class ValueType, class Less = std::less<ValueType>>
class term_rewriting_system{
public:
    using type = ValueType;
    // ...
};

頻出する関数

OCaml では再帰関数により繰り返しを実現していますが、その多くが C++ では効率の良いループによる記述に変換可能です。

null

対象の列を表現したリストが空かどうか判定します。

  • OCaml
let null s =
  match s with
    | []       -> true
    | (_ :: _) -> false
  • C++
    template<class Type>
    static bool null(const std::vector<Type> &s){
        return s.empty();
    }

zip

与えられた2つの列の要素を結合して新たなリストを生成する関数です。

  • OCaml
let rec zip =
  function ([], [])           -> []
         | (x :: xs, y :: ys) -> (x, y) :: zip (xs, ys)
         | (_, _)             -> raise INVALID_ARGUMENT
  • C++
    template<class X, class Y>
    static std::vector<std::pair<X, Y>> zip(const std::vector<X> &xs, const std::vector<Y> &ys){
        std::vector<std::pair<X, Y>> ret;
        auto x = xs.begin();
        auto y = ys.begin();
        while(x != xs.end()){
            ret.push_back(std::make_pair(*x, *y));
            ++x;
            ++y;
        }
        return ret;
    }

forall

与えられた列について、真偽値を返す関数 p で判定したとき、全ての要素が真になるかどうかを判定します。

  • OCaml
let rec forall p =
  function []        -> true
         | (x :: xs) -> p x && forall p xs
  • C++
    template<class X>
     static bool forall(std::function<bool(const X&)> p, const std::vector<X> &xs){
        bool r = true;
        for(auto &x : xs){
            r = r && p(x);
            if(!r){ break; }
        }
        return r;
    }

exists

与えられた列について、真偽値を返す関数 p で判定したとき、1つでも要素が真になるかどうかを判定します。

  • OCaml
let rec exists f =
  function []        -> false
         | (y :: ys) -> if f y then true else exists f ys
  • C++
    template<class X>
    static bool exists(std::function<bool(const X&)> p, const std::vector<X> &xs){
        bool r = false;
        for(auto &x : xs){
            r = r || p(x);
            if(r){ break; }
        }
        return r;
    }

map

要素を変換する関数 f によって新たな要素へと変換した列を生成します。
C++ 実装では、列の const reference を直接受け取り変換するものと、効率化のためにレンジベースで iterator を与えられたときに処理を行う関数の2つを用意します。

  • OCaml
let rec map f =
  function []        -> []
         | (x :: xs) -> (f x) :: map f xs
  • C++
    template<class X, class Y>
    static std::vector<Y> map(std::function<Y(const X&)> f, const std::vector<X> &xs){
        std::vector<Y> r;
        r.reserve(xs.size());
        for(auto &e : xs){
            r.push_back(f(e));
        }
        return r;
    }

    template<class X, class Y, class Iter>
    static std::vector<Y> map(std::function<Y(const X&)> f, Iter first, Iter last){
        std::vector<Y> r;
        r.reserve(last - first);
        for(; first != last; ++first){
            r.push_back(f(*first));
        }
        return r;
    }

concat

二重にネストされた列が与えられたときに、それらを全て結合した新たな列を生成します。

  • OCaml
let rec concat =
  function []        -> []
         | (x :: xs) -> x @ (concat xs)
  • C++
    template<class X>
    static std::vector<X> concat(const std::vector<std::vector<X>> &xs){
        std::vector<X> rs;
        for(auto  &s : xs){
            rs.insert(rs.end(), s.begin(), s.end());
        }
        return rs;
    }

allapp

列の要素を処理する関数 f 、 g と列が与えられたときに、列の要素全てを f 、 g を使って適用します。
C++ 実装では g の実装は効率化のために iterator ベースとなっています。

  • OCaml
let rec allapp f g =
  function []        -> ()
         | (x :: xs) -> f x; g xs; allapp f g xs
  • C++
    template<class X>
    static void allapp(std::function<void(const X&)> f, std::function<void(typename std::vector<X>::const_iterator, typename std::vector<X>::const_iterator)> g, const std::vector<X> &xs){
        for(auto iter = xs.begin(); iter != xs.end(); ++iter){
            f(*iter);
            g(iter + 1, xs.end());
        }
    }

データ構造

vname

比較時にあらゆる項にマッチングする項である変数の型を定義します。項の型と、その変数のナンバーから成っています。
C++ では、比較には term_rewriting_system 型の2番目の template parameter である Less functor を利用します。

  • OCaml
type vname = string * int
let compare_vname (v, i) (w, j) = String.compare v w
  • C++
    struct vname{
        type str;
        int i = 0;

        bool operator ==(const vname &other) const{
            return str == other.str && i == other.i;
        }
    };

    static vname make_vname(const type &str, int i = 0){
        vname v;
        v.str = str;
        v.i = i;
        return v;
    }

term 、 term_list

term は項を表現します。また、 term は内部で term のリストを再帰的に保持するため、 C++ 実装では term_list を別途用意し、 boost::variant を継承した vnameterm_list 両方を表現できるクラスとして term を実装しています。

  • OCaml
type term =
  | V of vname
  | T of string * term list
  • C++
    struct term;

    struct term_list{
        type str;
        std::vector<term> list;

        bool operator ==(const term_list &other) const{
            return str == other.str && list == other.list;
        }
    };

    struct term : public boost::variant<vname, term_list>{
        using base = boost::variant<vname, term_list>;
        using base::which;

        enum{
            type_variable = 0, type_term_list
        };

        term(const vname &v) : base(v){}
        term(vname &&v) : base(v){}
        term(const term_list &tl) : base(tl){}
        term(term_list &&tl) : base(tl){}
        term() : base(vname()){}
        term(const term &other) : base(vname()){
            switch(other.which()){
                case type_variable:
                    static_cast<base&>(*this) = boost::get<vname>(other);
                    break;

                case type_term_list:
                    static_cast<base&>(*this) = boost::get<term_list>(other);
                    break;
            }
        }

        term(term &&other) : base(vname()){
            switch(other.which()){
                case type_variable:
                    static_cast<base&>(*this) = std::move(boost::get<vname>(other));
                    break;

                case type_term_list:
                    static_cast<base&>(*this) = std::move(boost::get<term_list>(other));
                    break;
            }
        }

        ~term() = default;

        term &operator =(const term &other){
            switch(other.which()){
                case type_variable:
                    static_cast<base&>(*this) = boost::get<vname>(other);
                    break;

                case type_term_list:
                    static_cast<base&>(*this) = boost::get<term_list>(other);
                    break;
            }

            return *this;
        }

        bool operator ==(const term &other) const{
           if(which() != other.which()){
                return false;
            }else{
                if(which() == type_variable){
                    return boost::get<vname>(*this) == boost::get<vname>(other);
                }else{
                    return boost::get<term_list>(*this) == boost::get<term_list>(other);
                }
            }
        }
    };

ids

項のペアをリストで保持します。 Critical Pairs や書き換え規則の左辺・右辺を表現するために使用します。

  • OCaml
type ids = (term * term) list
  • C++
    using ids = std::vector<std::pair<term, term>>;

Utilities

項表現におけるのリストを直接生成する関数、同じく変数を生成する関数、定数を生成する関数、関数を生成する関数を今後のために記述しておきます。
boost::get<vname>(term) でキャストされた項はあらゆる項にマッチする変数を示し、ネストした項を持つ boost::get<term_list>(term) にキャストされた項は引数を持つ関数を指し、同様にネストされた項を1つも持たない boost::get<term_list>(term) は定数を指します。

  • C++
    static term make_term_list(const type &str, const std::vector<term> &list){
        return term_list({ str, list });
    }

    static term make_variable(const type &str, int i = 0){
        return make_vname(str, i);
    }

    static term make_variable(const vname &v){
        return v;
    }

    static term make_constant(const type &str){
        return make_term_list(str, std::vector<term>());
    }

    static term make_function(const type &str, const std::vector<term> &ts){
        return make_term_list(str, ts);
    }

subst

変数と項の組をリストで保持します。後にindom関数などで用います。

  • OCaml
type subst = (vname * term) list

let get_vname vt =
  match vt with
    | (v, t) -> v

let get_term vt =
  match vt with
    | (v, t) -> t
  • C++
    using vname_term = std::pair<vname, term>;
    using subst = std::vector<vname_term>;

    static const vname &get_vname(const vname_term &vt){
        return vt.first;
    }

    static const term &get_term(const vname_term &vt){
        return vt.second;
    }

汎用の関数から項書き換えまで

add_subst

subst に要素を追加します。

  • OCaml
let empty_subst = []

let add_subst s x t = (x, t) :: s
  • C++
    static subst add_subst(const subst s, const vname &x, const term &t){
        subst r = { std::make_pair(x, t) };
        r.insert(r.begin(), s.begin(), s.end());
        return r;
    }

indom 、 app

subst から一致する vname を持つ要素を検索します。

  • OCaml
let indom x s = exists (fun (y, _) -> x = y) s

let rec app s x =
  match s with
    | (y, t) :: rest -> if x = y then t else app rest x
    | _              -> raise INVALID_ARGUMENT
  • C++
    static bool indom(const vname &x, const subst &s){
        std::function<bool(const vname_term&)> f = [&](const vname_term &y){
            return x.str == y.first.str && x.i == y.first.i;
        };
        return exists(f, s);
    }

    static const term &app(const subst &s, const vname &x){
        for(auto &yt : s){
            if(yt.first == x){
                return yt.second;
            }
        }
        throw;
    }

lift

再帰的に自身を適用しつつ、一致する term が現れるまで subst を検索し続けます。

  • OCaml
let rec lift s t =
  match t with
    | V x       -> if indom x s then app s x else V x
    | T (f, ts) -> T (f, map (lift s) ts)
  • C++
    static term lift(const subst &s, const term &t){
        if(t.which() == term::type_variable){
            const vname &x = boost::get<vname>(t);
            if(indom(x, s)){
                return app(s, x);
            }else{
                return t;
            }
        }else{
            const term_list &tl = boost::get<term_list>(t);
            std::function<term(const term&)> f = [&s](const term &a){ return lift(s, a); };
            return make_term_list(tl.str, map(f, tl.list));
        }
    }

occurs

エラーが発生していないかどうか等をチェックする関数です。こちらも再帰的に自身を適用して実行していきます。

  • OCaml
let rec occurs x t =
  match t with
    | V y       -> x = y
    | T (_, ts) -> exists (occurs x) ts
  • C++
    static bool occurs(const vname &x, const term &t){
        if(t.which() == term::type_variable){
            return x == boost::get<vname>(t);
        }else{
            std::function<bool(const term&)> f = [&x](const term &a){ return occurs(x, a); };
            return exists(f, boost::get<term_list>(t).list);
        }
    }

一意化の例外

アルゴリズム中、書き換え対象の一意化に失敗したかどうかを検出するために例外を捕捉する専用の記述を行います。

  • OCaml
ception UNIFY
  • C++
    struct unify_exception{};

solve

次から標準形 $T$ を求めます。

\left\{ s_{1} =^{?} t_{1},\ \dots, \ s_{n} =^{?} t_{n} \right\}

C++ 実装では効率化のために不変のリストに対して iterator を用いています。

  • OCaml
let rec solve ttlist_and_subst =
  match ttlist_and_subst with
    | ([], s)
      -> s
    | ((V x, t) :: rest, s)
      -> if V x = t then solve (rest, s) else elim x t rest s
    | ((t, V x) :: rest, s)
      -> elim x t rest s
    | ((T (f, ts), T (g, us)) :: rest, s)
      -> if f = g then solve (zip(ts, us) @ rest, s) else raise UNIFY
  • C++
    static subst solve(typename ids::const_iterator tt_first, typename ids::const_iterator tt_last, const subst &s){
        if(tt_first == tt_last){
            return s;
        }else if(tt_first->first.which() == term::type_variable){
            if(tt_first->second.which() == term::type_variable && boost::get<vname>(tt_first->first) == boost::get<vname>(tt_first->second)){
                return solve(tt_first + 1, tt_last, s);
            }else{
                return elim(boost::get<vname>(tt_first->first), tt_first->second, tt_first + 1, tt_last, s);
             }
        }else if(tt_first->second.which() == term::type_variable){
            return elim(boost::get<vname>(tt_first->second), tt_first->first, tt_first + 1, tt_last, s);
        }else{
            const term_list &ts = boost::get<term_list>(tt_first->first), &us = boost::get<term_list>(tt_first->second);
            if(ts.str == us.str){
                ids ts_us_rest = zip(ts.list, us.list);
                ts_us_rest.insert(ts_us_rest.end(), tt_first + 1, tt_last);
                return solve(ts_us_rest.begin(), ts_us_rest.end(), s);
            }else{
                throw unify_exception();
            }
        }
    }

elim

続いて、不要な規則を消去する関数です。こちらも不要なリストの生成が発生しないよう、 C++ では必要な箇所に応じて iterator を用いています。

  • OCaml
and elim x t rest s =
  let xt = lift (add_subst empty_subst x t) in
    if occurs x t then raise UNIFY
    else solve(map (fun (t1, t2) -> (xt t1, xt t2)) rest,
                 (x, t) :: (map (fun (y, u) -> (y, xt u)) s))
  • C++
    static subst elim(const vname &x, const term &t, typename ids::const_iterator rest_first, typename ids::const_iterator rest_last, const subst &s){
        auto xt = [&x, &t](const term &u){
            subst w = { std::make_pair(x, t) };
            return lift(w, u);
        };
        if(occurs(x, t)){
            throw unify_exception();
        }else{
            std::function<std::pair<term, term>(const std::pair<term, term>&)> f = [&xt](const std::pair<term, term> &tt){
                return std::make_pair(xt(tt.first), xt(tt.second));
            };
            ids mapped_rest = map(f, rest_first, rest_last);
            std::function<vname_term(const vname_term&)> g = [&xt](const vname_term &yu){
                return std::make_pair(yu.first, xt(yu.second));
            };
            return solve(mapped_rest.begin(), mapped_rest.end(), add_subst(map(g, s), x, t));
        }
    }

unify

規則の一意化を行います。

  • OCaml
let unify (t1, t2) = solve ([(t1, t2)], [])
  • C++
    static subst unify(const std::pair<term, term> &tt){
        ids v = { tt };
        return solve(v.begin(), v.end(), subst());
    }

match 、 pattern_match

項同士のマッチングを行います。こちらも、無駄を削減するために C++ 実装では iterator を用いています。

  • OCaml
let rec matchs ttlist s =
  match ttlist with
    | []
      -> s
    | (V x, t) :: rest
      -> if indom x s then if app s x = t then matchs rest s else raise UNIFY
         else matchs rest (add_subst s x t)
    | (t, V x) :: rest
      -> raise UNIFY
    | (T (f, ts), T (g, us)) :: rest
      -> if f = g then matchs (zip (ts, us) @ rest) s else raise UNIFY

let pattern_match pat obj = matchs [(pat, obj)] empty_subst
  • C++
    static subst matchs(typename ids::const_iterator tt_first, typename ids::const_iterator tt_last, const subst &s){
        if(tt_first == tt_last){
            return s;
        }else{
            const term &t1 = tt_first->first, &t2 = tt_first->second;
            if(t1.which() == term::type_variable){
                const vname &x = boost::get<vname>(t1);
                const term &t = t2;
                if(indom(x, s)){
                    if(app(s, x) == t){
                        return matchs(tt_first + 1, tt_last, s);
                    }else{ throw unify_exception(); }
                }else{
                    return matchs(tt_first + 1, tt_last, add_subst(s, x, t));
                }
            }else if(t2.which() == term::type_variable){
                throw unify_exception();
            }else{
                const term_list &ts = boost::get<term_list>(t1), &us = boost::get<term_list>(t2);
                if(ts.str == us.str){
                    ids ts_us_rest = zip(ts.list, us.list);
                    ts_us_rest.insert(ts_us_rest.end(), tt_first + 1, tt_last);
                    return matchs(ts_us_rest.begin(), ts_us_rest.end(), s);
                }else{ throw unify_exception(); }
            }
        }
    }

    static subst pattern_match(const term &pat, const term &obj){
        ids v = { std::make_pair(pat, obj) };
        return matchs(v.begin(), v.end(), subst());
    }

rewrite 、 norm

書き換えを行い、標準形を求めます。失敗した場合は例外を送出します。

  • OCaml
exception NORM

let rec rewrite ttlist t =
  try try_rewrite ttlist t with
    | UNIFY -> retry_rewrite ttlist t

and try_rewrite ttlist t =
  match ttlist with
    | []               -> raise NORM
    | ((l, r) :: rest) -> lift (pattern_match l t) r

and retry_rewrite ttlist t =
  match ttlist with
    | []               -> raise NORM
    | ((l, r) :: rest) -> rewrite rest t

let rec norm r t =
  match t with
    | V x       -> V x
    | T (f, ts) -> inner_norm r f ts

and inner_norm r f ts =
  let u = T (f, map (norm r) ts) in
    try norm r (rewrite r u) with
      | NORM -> u
  • C++
    struct norm_exception{};

    static term rewrite(typename ids::const_iterator tt_first, typename ids::const_iterator tt_last, const term &t){
        for(; tt_first != tt_last; ++tt_first){
            try{
                const term &l = tt_first->first, &r = tt_first->second;
                return lift(pattern_match(l, t), r);
            }catch(unify_exception){
                continue;
            }
        }
        throw norm_exception();
    }

    static term norm(const ids &r, const term &t){
        if(t.which() == term::type_variable){
            return t;
        }else{
            const term_list &ts = boost::get<term_list>(t);
            std::function<term(const term&)> f = [&r](const term &a){ return norm(r, a); };
            term u = make_term_list(ts.str, map(f, ts.list));
            try{
                return norm(r, rewrite(r.begin(), r.end(), u));
            }catch(norm_exception){
                return u;
            }
        }
    }

order

項比較時のオーダーの記述を行います。
Greater 、 Equal 、 Less (NGE) に分けられます。
書き換え対象の項の構造は定数項、変数項を葉とした木構造になります。 lex 関数は純粋な列を比較し、 lpo 関数は木構造同士の項を深さ優先探査で比較していきます。

  • OCaml
type order =
  | GR
  | EQ
  | NGE

let int_to_order a = if a > 0 then GR else if a = 0 then EQ else NGE

let rec lex ord alpha_list_and_beta_list =
  match alpha_list_and_beta_list with
    | ([], [])           -> EQ
    | (x :: xs, y :: ys) -> inner_lex (ord (x, y)) ord xs ys
    | (_, _)             -> raise INVALID_ARGUMENT

and inner_lex o ord xs ys =
  match o with
    | GR  -> GR
    | EQ  -> lex ord (xs, ys)
    | NGE -> NGE

let rec lpo ord st =
  match st with
    | (s, V x)
      -> if s = V x then EQ
         else if occurs x s then GR else NGE
    | (V _, T _)
      -> NGE
    | (T (f, ss), T (g, ts))
      -> if forall (fun si -> lpo ord (si, T (g, ts)) = NGE) ss
         then inner_lpo (ord (f, g)) ord (T (f, ss)) ss ts
         else GR

and inner_lpo o ord s ss ts =
  match o with
    | GR
      -> if forall (fun ti -> lpo ord (s, ti) = GR) ts
         then GR else NGE
    | EQ
      -> if forall (fun ti -> lpo ord (s, ti) = GR) ts
      then lex (lpo ord) (ss, ts)
      else NGE
    | NGE
      -> NGE

let lpo_functor (t, u) = (lpo (fun (a, b) -> int_to_order (String.compare a b))) (t, u)

  • C++
    enum class order : int{
        nge = -1,
        eq = 0,
        gr = 1
    };

    template<class X, class Iter>
    static order lex(std::function<order(const X&, const X&)> f, Iter iter, Iter a_end, Iter jter, Iter b_end){
        for(; iter != a_end && jter != b_end; ++iter, ++jter){
            order ord = f(*iter, *jter);
            if(ord != order::eq){
                return ord;
            }
        }
        if(iter == a_end && jter == b_end){
            return order::eq;
        }
        if(iter != a_end && jter ==b_end){
            return order::gr;
        }else{
            return order::nge;
        }
    }

    static order lpo(std::function<order(const type&, const type&)> ord, const term &s, const term &t){
        if(t.which() == term::type_variable){
            if(s == t){
                return order::eq;
            }else if(occurs(boost::get<vname>(t), s)){
                return order::gr;
            }else{
                return order::nge;
            }
        }else if(s.which() == term::type_variable && t.which() == term::type_term_list){
            return order::nge;
        }else{
            const term_list &ss = boost::get<term_list>(s), &ts = boost::get<term_list>(t);
            std::function<bool(const term&)> f = [&ord, &t](const term &si){ return lpo(ord, si, t) == order::nge; };
            if(forall(f, ss.list)){
                std::function<bool(const term&)> g = [&ord, &s](const term &a){ return lpo(ord, s, a) == order::gr; };
                switch(ord(ss.str, ts.str)){
                    case order::gr:                            
                        if(forall(g, ts.list)){
                            return order::gr;
                        }else{
                            return order::nge;
                        }

                    case order::eq:
                        if(forall(g, ts.list)){
                            std::function<order(const term&, const term&)> h = [&ord](const term &a, const term &b){ return lpo(ord, a, b); };
                            return lex(h, ss.list.begin(), ss.list.end(), ts.list.begin(), ts.list.end());
                        }else{
                            return order::nge;
                        }

                    case order::nge:
                        return order::nge;
                }
            }
            return order::gr;
        }
    }

    static order lpo_functor(const term &t, const term &u){
        std::function<order(const type&, const type&)> f = [](const type &a, const type &b){
            bool p = Less()(a, b);
            bool q = Less()(b, a);
            if(!p && !q){
                return order::eq;
            }else if(q){
                return order::gr;
            }else{
                return order::nge;
            }
        };
        return lpo(f, t, u);
    }

    class term_comparetor{
    public:
        bool operator()(const term &t, const term &u) const{
            return lpo_functor(t, u) == order::nge;
        }
    };

rename

項の中に存在する変数の番号を変更します。

  • OCaml
let rec rename n =
  function (V (x, i))  -> V (x, i + n)
         | (T (f, ts)) -> T (f, map (rename n) ts)

let rec maxs =
  function (i :: is) -> max i (maxs is)
         | []        -> 0
  • C++
    static term rename(int n, const term &t){
        if(t.which() == term::type_variable){
            term r = t;
            boost::get<vname>(r).i += n;
            return r;
        }else{
            const term_list &ts = boost::get<term_list>(t);
            std::function<term(const term&)> f = [n](const term &a){ return rename(n, a); };
            return make_term_list(ts.str, map(f, ts.list));
        }
    }

maxs 、 maxindex

それぞれリスト中にあるintから最大値を得る関数、項中にある変数の番号から最大値を得る関数です。

  • OCaml
let rec maxs =
  function (i :: is) -> max i (maxs is)
         | []        -> 0

let rec maxindex =
  function (V (x, i))  -> i
         | (T (_, ts)) -> maxs (map maxindex ts)
  • C++
    static int maxs(const std::vector<int> &s){
        int n = 0;
        for(auto &a : s){
            n = std::max(a, n);
        }
        return n;
    }

    static int maxindex(const term &t){
        if(t.which() == term::type_variable){
            return boost::get<vname>(t).i;
        }else{
            int n = 0;
            for(auto &u : boost::get<term_list>(t).list){
                n = std::max(maxindex(u), n);
            }
            return n;
        }
    }

critical_pairs

合流する規則を追加するために、危険対を得ます。変更の加わらない項のリストについては他の関数と同様に C++ では iterator を使って不要なコピーを防ぐコードにしました。
λ式を引数に取ったり関数から返したりするので少しややこしいです。

  • OCaml
let ccp c tr l2r2 =
  match tr with
    | (t, r) -> match l2r2 with
      | (l2, r2) ->
          try
            [(lift (unify (t, l2)) r, lift (unify (t, l2)) (c r2))]
          with
            | UNIFY -> []

let rec ccps ttlist (l, r) =
  let rec cps c =
    function (V _, _)       -> []
           | (T (f, ts), r) -> concat (map (ccp c (T (f, ts), r)) ttlist) @ (inner_cps c (f, [], ts, r))

    and inner_cps c =
      function (_, _, [], _)         -> []
             | (f, ts0, t :: ts1, r)
               -> let cf s = c (T (f, ts0 @ [s] @ ts1)) in
                    (cps cf (t, r)) @ (inner_cps c (f, ts0 @ [t], ts1, r))

    and m t = rename (maxs (map (fun (ts, us) -> max (maxindex ts) (maxindex us)) ttlist) + 1) t

  in cps (fun t -> t) (m l, m r)

let critical_pairs2 r1 r2 = concat (map (ccps r1) r2)

let critical_pairs r = critical_pairs2 r r
  • C++
    static ids ccp(
        std::function<term(const term&)> c,
        const term &t,
        const term &r,
        const term &l2,
        const term &r2
    ){
        try{
            return ids({std::make_pair(
                lift(unify(std::make_pair(t, l2)), r),
                lift(unify(std::make_pair(t, l2)), c(r2))
            )});
        }catch(unify_exception){
            return ids();
        }
    }

    static ids ccps(const ids &ttlist, const term &l, const term &r){
        std::function<ids(std::function<term(const term&)>, const term&, const term&)> cps;
        cps = [&](std::function<term(const term&)> c, const term &ts_, const term &r){
            if(ts_.which() == term::type_variable){
                return ids();
            }else{
                const term_list &ts = boost::get<term_list>(ts_);
                std::function<ids(const std::pair<term, term>&)> f = [&](const std::pair<term, term> &l2r2){
                    return ccp(c, ts_, r, l2r2.first, l2r2.second);
                };
                std::function<
                    ids(
                        std::function<term(const term&)>,
                        const type&,
                        const term_list&,
                        typename std::vector<term>::const_iterator,
                        typename std::vector<term>::const_iterator,
                        const term&
                    )
                > inner_cps = [&](
                    std::function<term(const term&)> c,
                    const type &f,
                    const term_list &ts0,
                    typename std::vector<term>::const_iterator ts1_first,
                    typename std::vector<term>::const_iterator ts1_last,
                    const term &r
                ){
                    if(ts1_first == ts1_last){
                        return ids();
                    }else{
                        std::function<term(const term&)> cf = [&](const term &s){
                            std::vector<term> v = ts0.list;
                            v.reserve(v.size() + (ts1_last - ts1_first));
                            v.push_back(s);
                            v.insert(v.end(), ts1_first + 1, ts1_last);
                            return c(make_term_list(f, v));
                        };
                        term_list ts0_t = ts0;
                        ts0_t.list.push_back(*ts1_first);
                        ids
                            tmp1 = cps(cf, *ts1_first, r),
                            tmp2 = inner_cps(c, f, ts0_t, ts1_first + 1, ts1_last, r);
                        tmp1.insert(tmp1.end(), tmp2.begin(), tmp2.end());
                        return tmp1;
                    }
                };
                std::function<ids(const std::pair<term, term>&)> wrapped_ccp = [&](const std::pair<term, term> &l2r2){
                    return ccp(c, ts, r, l2r2.first, l2r2.second);
                };
                std::vector<ids> ccp_ttlist_result = map(wrapped_ccp, ttlist);
                {
                    ids inner_cps_result = inner_cps(c, ts.str, term_list(), ts.list.begin(), ts.list.end(), r);
                    ccp_ttlist_result.push_back(inner_cps_result);
                }
                return concat(ccp_ttlist_result);
            }
        };
        auto m = [&](const term &t){
            int mtu = 0;
            for(auto &p : ttlist){
                mtu = std::max(maxindex(p.first), mtu);
                mtu = std::max(maxindex(p.second), mtu);
            }
            return rename(mtu + 1, t);
        };
        std::function<term(const term&)> identity = [](const term &a){ return a; };
        return cps(identity, m(l), m(r));
    }

    static ids critical_pairs2(const ids &r1, const ids &r2){
        ids ttlist;
        for(auto &iter : r2){
            ids tt = ccps(r1, iter.first, iter.second);
            ttlist.insert(ttlist.end(), tt.begin(), tt.end());
        }
        return ttlist;
    }

    static ids critical_pairs(const ids &r){
        return critical_pairs2(r, r);
    }

    struct context_element{
        type str;
        std::vector<term> ts, us;
    };

add_rule

書き換え規則を表す ids に新たに標準形の書き換え規則を追加します。 C++ 実装の内部のsimpl関数は、再起関数からループへ変換できませんでした。

  • OCaml
let add_rule (l, r, ids_e, ids_s, ids_r) =
  let rec simpl triple_ids =
    match triple_ids with
      | ([], e', r')
        -> (e', r')
      | ((g, d) :: u, e', u')
        -> let g' = norm [(l, r)] g
             in if g' = g
             then let d' = norm ((l, r) :: ids_r @ ids_s) d
                  in simpl (u, e', (g, d') :: u')
             else simpl (u, (g', d) :: ids_e, u')
  in let (e', s') = simpl (ids_s, ids_e, [])
  in let (e'', r') = simpl (ids_r, e', [])
  in (e'', (l, r) :: s', r')
  • C++
    static std::tuple<ids, ids, ids> add_rule(const term &l,  const term &r, typename ids::const_iterator ids_e_first, typename ids::const_iterator ids_e_last, const ids &ids_s, const ids &ids_r){
        std::function<
            std::tuple<ids, ids>(
                typename ids::const_iterator,
                typename ids::const_iterator,
                typename ids::const_iterator,
                typename ids::const_iterator,
                const ids&
            )
        > simpl = [&](
            typename ids::const_iterator u_first,
            typename ids::const_iterator u_last,
            typename ids::const_iterator e_prime_first,
            typename ids::const_iterator e_prime_last,
            const ids &u_prime
        ){
            if(u_first == u_last){
                return make_tuple(ids(e_prime_first, e_prime_last), u_prime);
            }else{
                const term &g = u_first->first, &d = u_first->second;
                ids lr = { std::make_pair(l, r) };
                term g_prime = norm(lr, g);
                if(g_prime == g){
                    term d_prime;
                    {
                        ids lr_r_s;
                        lr_r_s.reserve(ids_r.size() + ids_s.size() + 1);
                        lr_r_s.push_back(std::make_pair(l, r));
                        lr_r_s.insert(lr_r_s.end(), ids_r.begin(), ids_r.end());
                        lr_r_s.insert(lr_r_s.end(), ids_s.begin(), ids_s.end());
                        d_prime = norm(lr_r_s, d);
                    }
                    ids g_d_prime_u_prime;
                    g_d_prime_u_prime.reserve(u_prime.size() + 1);
                    g_d_prime_u_prime.push_back(std::make_pair(g, d_prime));
                    g_d_prime_u_prime.insert(g_d_prime_u_prime.end(), u_prime.begin(), u_prime.end());
                    return simpl(u_first + 1, u_last, e_prime_first, e_prime_last, g_d_prime_u_prime);
                }else{
                    ids g_prime_d_e;
                    g_prime_d_e.reserve(ids_e_last - ids_e_first + 1);
                    g_prime_d_e.push_back(std::make_pair(g_prime, d));
                    g_prime_d_e.insert(g_prime_d_e.end(), ids_e_first, ids_e_last);
                    return simpl(u_first + 1, u_last, g_prime_d_e.begin(), g_prime_d_e.end(), u_prime);
                }
            }
        };
        std::tuple<ids, ids> e_prime_s_prime = simpl(ids_s.begin(), ids_s.end(), ids_e_first, ids_e_last, ids());
        ids &e_prime = std::get<0>(e_prime_s_prime), &s_prime = std::get<1>(e_prime_s_prime);
        std::tuple<ids, ids> e_wprime_r_prime = simpl(ids_r.begin(), ids_r.end(), e_prime.begin(), e_prime.end(), ids());
        ids &e_wprime = std::get<0>(e_wprime_r_prime), r_prime = std::get<1>(e_wprime_r_prime);
        ids lr_s_prime;
        lr_s_prime.reserve(s_prime.size() + 1);
        lr_s_prime.push_back(std::make_pair(l, r));
        lr_s_prime.insert(lr_s_prime.end(), s_prime.begin(), s_prime.end());
        return make_tuple(e_wprime, lr_s_prime, r_prime);
    }

fail

完備化に失敗したときに送出される例外です。書き換え規則の比較時に Less then 、Greater then のどちらも成り立たないときの他、 orient 関数内でも送出されます。

  • OCaml
exception FAIL
  • C++
struct fail_exception{};

orient

Functor を返す関数です。この関数についても、 iterator を使った効率化を施しました。

  • OCaml
let orient ord vord =
  let rec ori triple_ids =
    match triple_ids with
      | ([], ids_s, ids_r)
        -> (ids_s, ids_r)
      | ((s, t) :: ids_e, ids_s, ids_r)
        -> let s' = norm (ids_r @ ids_s) s
           and t' = norm (ids_r @ ids_s) t
           in if s' = t' then ori (ids_e, ids_s, ids_r)
                         else if ord (s', t') = GR then ori (add_rule (s', t', ids_e, ids_s, ids_r))
                         else if ord (t', s') = GR then ori (add_rule (t', s', ids_e, ids_s, ids_r))
                         else raise FAIL
  in ori
  • C++
    static std::function<std::pair<ids, ids>(typename ids::const_iterator, typename ids::const_iterator, const ids&, const ids&)> orient(std::function<order(const term&, const term&)> ord){
        std::function<std::pair<ids, ids>(typename ids::const_iterator, typename ids::const_iterator, const ids&, const ids&)> ori = [&](typename ids::const_iterator e_first, typename ids::const_iterator e_last, const ids &ids_s, const ids &ids_r){
            if(e_first == e_last){
                return std::make_pair(ids_s, ids_r);
            }else{
                const term &s = e_first->first, &t = e_first->second;
                ids rs;
                rs.reserve(ids_r.size() + ids_s.size());
                rs.insert(rs.end(), ids_r.begin(), ids_r.end());
                rs.insert(rs.end(), ids_s.begin(), ids_s.end());
                term s_prime = norm(rs, s), t_prime = norm(rs, t);
                if(ord(s_prime, t_prime) == order::eq){
                    return ori(e_first + 1, e_last, ids_s, ids_r);
                }else if(ord(s_prime, t_prime) == order::gr){
                    std::tuple<ids, ids, ids> tmp = add_rule(s_prime, t_prime, e_first + 1, e_last, ids_s, ids_r);
                    return ori(std::get<0>(tmp).begin(), std::get<0>(tmp).end(), std::get<1>(tmp), std::get<2>(tmp));
                }else if(ord(t_prime, s_prime) == order::gr){
                    std::tuple<ids, ids, ids> tmp = add_rule(t_prime, s_prime, e_first + 1, e_last, ids_s, ids_r);
                    return ori(std::get<0>(tmp).begin(), std::get<0>(tmp).end(), std::get<1>(tmp), std::get<2>(tmp));
                }
                throw fail_exception();
            }
        };
        return ori;
    }

size

項の大きさを得る関数です。項を木構造と見做したとき、 node にあたる variableterm_list の数をそれぞれカウントします。

  • OCaml
let rec size t =
  match t with
    | (V _)       -> 1
    | (T (_, ts)) -> sizes ts + 1
  and sizes ts =
    match ts with
      | []        -> 0
      | (t :: ts) -> size t + sizes ts
  • C++
    static std::size_t size(const term &t){
        if(t.which() == term::type_variable){
            return 1;
        }else{
            const term_list &ts = boost::get<term_list>(t);
            std::size_t s = 1;
            for(auto &iter : ts.list){
                s += size(iter);
            }
            return s;
        }
    }

min_rule

上記の size 関数を利用し、小さな書き換え規則を比較して得ます。

  • OCaml
let rec min_rule tt_i_ids_ids =
  match tt_i_ids_ids with
    | (rl, n, [], r')
      -> (rl, r')
    | (rl, n, (l, r) :: rest, r')
      -> let m = size l + size r
           in if m < n then min_rule ((l, r), m, rest, rl :: r')
                       else min_rule (rl, n, rest, (l, r) :: r')
  • C++
    static std::tuple<const term*, const term*, ids> min_rule(const term &t, const term &u, std::size_t n, typename ids::const_iterator r_first, typename ids::const_iterator r_last, const ids &r_prime){
        if(r_first == r_last){
            return make_tuple(&t, &u, r_prime);
        }else{
            const term &l = r_first->first, &r = r_first->second;
            std::size_t m = size(l) + size(r);
                ids nr;
                nr.reserve(r_prime.size() + 1);
            if(m < n){
                nr.push_back(std::make_pair(t, u));
                nr.insert(nr.end(), r_prime.begin(), r_prime.end());
                return min_rule(l, r, m, r_first + 1, r_last, nr);
            }else{
                nr.push_back(std::make_pair(l, r));
                nr.insert(nr.end(), r_prime.begin(), r_prime.end());
                return min_rule(t, u, n, r_first + 1, r_last, nr);
            }
        }
    }

choose

min_rule で比較して得た小さな規則を順に得ます。

  • OCaml
let choose ttlist =
  match ttlist with
    | ((l, r) :: rest) -> min_rule ((l, r), size l + size r, rest, [])
    | _                -> raise INVALID_ARGUMENT
  • C++
    static std::tuple<const term*, const term*, ids> choose(typename ids::const_iterator r_first, typename ids::const_iterator r_last){
        return min_rule(r_first->first, r_first->second, size(r_first->first) + size(r_first->second), r_first + 1, r_last, ids());
    }

complete

Knuth-Bendix の完備化アルゴリズム本体です。 OCaml では内部に compl 関数を記述して再帰させループしていますが、 C++ では単純な while ループに置き換えています。
C++ 版では第 3 引数に試行回数を指定します。この回数に到達するか、完備化が完了、あるいは失敗するまでアルゴリズムは繰り返されます。 0 を指定すると回数の上限の設定を無効にし、成功か失敗するまで停止しなくなります。

  • OCaml
let complete ord vord ids_e =
  let rec compl esr =
    match orient ord vord esr with
      | ([], r')
        -> r'
      | (s', r')
        -> let (rl, s'') = choose s'
           in let cps = critical_pairs2 [rl] r' @
                        critical_pairs2 r' [rl] @
                        critical_pairs2 [rl] [rl]
            in compl (cps, s'', rl :: r')
  in compl (ids_e, [], [])
  • C++
    static ids complete(std::function<order(const term&, const term&)> ord, const ids &e, std::size_t break_count = 0){
        std::function<ids(ids, ids, ids)> inner_completion = [&](ids ids_e, ids ids_s, ids ids_r){
            std::pair<ids, ids> sr;
            std::size_t count = 0;
            while(break_count == 0 || count < break_count){
                std::function<std::pair<ids, ids>(typename ids::const_iterator, typename ids::const_iterator, const ids&, const ids&)> f = orient(ord);
                sr = f(ids_e.begin(), ids_e.end(), ids_s, ids_r);
                if(sr.first.empty()){
                    break;
                }else{
                    const ids &s_prime = sr.first, &r_prime = sr.second;
                    std::tuple<const term*, const term*, ids> rl_s_wprime = choose(s_prime.begin(), s_prime.end());
                    ids cps, cps_tmp, rl = { std::make_pair(*std::get<0>(rl_s_wprime), *std::get<1>(rl_s_wprime)) };
                    cps = critical_pairs2(rl, r_prime);
                    cps_tmp = critical_pairs2(r_prime, rl);
                    cps.insert(cps.begin(), cps_tmp.begin(), cps_tmp.end());
                    cps_tmp = critical_pairs2(rl, rl);
                    cps.insert(cps.begin(), cps_tmp.begin(), cps_tmp.end());
                    cps_tmp = ids();
                    ids rl_r_prime;
                    rl_r_prime.reserve(r_prime.size() + 1);
                    rl_r_prime.push_back(std::make_pair(*std::get<0>(rl_s_wprime), *std::get<1>(rl_s_wprime)));
                    rl_r_prime.insert(rl_r_prime.begin(), r_prime.begin(), r_prime.end());
                    ids_e = std::move(cps);
                    ids_s = std::move(std::get<2>(rl_s_wprime));
                    ids_r = std::move(rl_r_prime);
                }
                ++count;
            }
            return sr.second;
        };
        return inner_completion(e, ids(), ids());
    }

動作させてみる

出力関数

動作結果を得るために、 term を出力する関数を記述します。
vname が変数、空の term_list が定数、引数を持っている term_list が関数です。

#include <iostream>
#include <string>
#include <memory>

template<class TRS>
static void print_term(const typename TRS::term &t){
    switch(t.which()){
        case TRS::term::type_variable:
            std::cout << boost::get<typename TRS::vname>(t).str;
            break;

        case TRS::term::type_term_list:
            std::cout << boost::get<typename TRS::term_list>(t).str;
            if(!TRS::null(boost::get<typename TRS::term_list>(t).list)){
                std::cout << "(";
            }
            std::function<void(const typename TRS::term&)> f = [](const typename TRS::term &t){
                print_term<TRS>(t);
            };
            std::function<void(typename std::vector<typename TRS::term>::const_iterator, typename std::vector<typename TRS::term>::const_iterator)>
                g = [](typename std::vector<typename TRS::term>::const_iterator beg, typename std::vector<typename TRS::term>::const_iterator end){
                     if(beg != end){ std::cout << ", "; }
                };
            TRS::allapp(f, g, boost::get<typename TRS::term_list>(t).list);
            if(!TRS::null(boost::get<typename TRS::term_list>(t).list)){
                std::cout << ")";
            }
            break;
    }
}

テストとして群を完備化してみます。群は以下の3つの公理から成り立っています。

  • 結合律
    $(x * y) * z = x * (y * z)$

  • 逆元の存在
    $x * i(x) = e$

  • 単位元 e の存在
    $e * x = x$

int main(){
    // TRS を行う型を定義する.
    using trs = term_rewriting_system<std::string>;
    // 関数 f(x, y), 上記の例の演算子 "*" に相当する.
    auto f = [&](const trs::term &a, const trs::term &b){ return trs::make_function("f", { a, b }); };
    // 逆関数.
    auto i = [&](const trs::term &a){ return trs::make_function("i", { a }); };
    // 単位元.
    auto e = trs::make_constant("e");
    // 変数を返す lambda 式.
    auto var = [&](const std::string &str){ return trs::make_variable(str); };

    // 結合律による等式.
    std::pair<trs::term, trs::term>
        rule_1 = std::make_pair(f(f(var("x"), var("y")), var("z")), f(var("x"), f(var("y"), var("z"))));
    // 逆元の存在による等式.
    std::pair<trs::term, trs::term>
        rule_2 = std::make_pair(f(i(var("x")), var("x")), e);
    // 単位元の存在による等式.
    std::pair<trs::term, trs::term>
        rule_3 = std::make_pair(f(e, var("x")), var("x"));
    // 上記3つの書き換え規則をまとめる.
    trs::ids rules = { rule_1, rule_2, rule_3 };
    // 完備化する.
    trs::ids result = trs::complete(trs::lpo_functor, rules);

    // 書き換え規則を出力する.
    auto print_rules = [&](const trs::ids &rules){
        for(auto &iter : rules){
            print_term<trs>(iter.first);
            std::cout << " => ";
            print_term<trs>(iter.second);
            std::cout << std::endl;
        }
    };

    // 完備化する前の書き換え規則を出力する.
    std::cout << "group:" <<  std::endl;
    print_rules(rules);
    std::cout << std::endl;

    // 完備化した後の書き換え規則を出力する.
    std::cout << "completed group:" <<  std::endl;
    print_rules(result);
    std::cout << std::endl;

    return 0;
}

出力は次の通りです。

group:
f(f(x, y), z) => f(x, f(y, z))
f(i(x), x) => e
f(e, x) => x

completed group:
f(x, i(x)) => e
i(i(x)) => x
f(x, e) => x
f(i(x), f(x, z)) => z
f(e, x) => x
f(i(x), x) => e
f(f(x, y), z) => f(x, f(y, z))
i(e) => e
f(x, f(i(x), z)) => z
i(f(x, y)) => f(i(y), i(x))

しっかり完備化できていることが分かります。

参考

当ソースコード
研究の紹介:項書き換えシステム入門

11
9
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
11
9