Variadic Template を使って switch を使ったテンプレート関数呼び出しを除去する

  • 13
    Like
  • 3
    Comment
More than 1 year has passed since last update.

この記事は C++ Advent Calendar 2015 の 2日目の記事です.
前の日は ignis_fatuus さんの Metashellを使ったC++メタプログラミングの入門とデバッグ です.
次の日は Riyaaaaa さんの C++ AMPによるGPGPU入門 です.

この記事で述べること

ありがちなテンプレート関数を呼び分けるダサい処理

switch (type) {
case Enum::INT: return DoSomething<int>();
case Enum::DOUBLE: return DoSomething<double>();
// 他の様々な型が続く...
}

を, boost::apply_visitor を一般化する感じで,

type_list::apply(DoSomething(), type);

と書けるようにしてみました.

背景

C++ で動的ディスパッチが必要なプログラムを書いている人は, しばしば上記のような switch 文等による呼び分けが必要になることがあると思います.
小さなプログラムならまだ耐えられますが, 規模が大きくなってきて, switch がそこかしこに散らばっている状態になるとかなりつらいです.
例えば DSL を実装していて, データ型がいろいろあって, 処理がいろいろあって, といった状態になると, 仕様変更で型を追加するだけでもかなり影響範囲の大きい仕事になります.

サンプル

本稿では, 整数の四則演算を解くプログラムを考えていきたいと思います. 書かれたコードを試すには C++14 対応コンパイラをご用意ください. Coliru で動作確認済みです. この記事に書かれたコードの一部または全てを用いて何か損害がでても当方はなんの責任も負いません.

とりあえず初期実装

static int calculate(char op, int lhs, int rhs)
{
    switch (op) {
    case '+':
        return lhs + rhs;
    case '-':
        return lhs - rhs;
    case '*':
        return lhs * rhs;
    case '/':
        return (rhs != 0) ? lhs / rhs : throw std::invalid_argument("zero division");
    default:
        break;
    }
    throw std::invalid_argument("invalid operator");
}

演算子とオペランドを受け取って計算結果を返します. 簡単ですね. これに演算子を追加していきます.
case の中がどの分岐でもほとんど一緒なので, テンプレートを使いたくなりますね.

     switch (op) {
     case '+':
-        return lhs + rhs;
+        return std::plus<int>()(lhs, rhs);
     case '-':
-        return lhs - rhs;
+        return std::minus<int>()(lhs, rhs);
     case '*':
-        return lhs * rhs;
+        return std::multiplies<int>()(lhs, rhs);
     case '/':
-        return (rhs != 0) ? lhs / rhs : throw std::invalid_argument("zero division");
+        return (rhs != 0) ? std::divides<int>()(lhs, rhs)
+                          : throw std::invalid_argument("zero division");
     default:
         break;
     }

ここに剰余算の処理を追加しようとした時に, 普通なら case '%' を追加するところですが, すでに使われている関数を変更するのはなんか嫌です. 既存のコードを変更せずに機能を拡張していきましょう.

Switch 文を除去

します.

#include <cstdint>
#include <functional>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>

// T から const/volatile/reference を除いた型
template <typename T>
using remove_cv_reference_t =
    typename std::remove_cv<typename std::remove_reference<T>::type>::type;
// T0,...,Ts の中で T が出現する最初の位置
template <size_t I, typename T, typename T0, typename... Ts>
struct index_of_impl {
    static const size_t value = (std::is_same<T, T0>::value)
                                    ? I
                                    : index_of_impl<I + 1, T, Ts...>::value;
};
// T0,...,Ts の中で T が出現する最初の位置
template <size_t I, typename T, typename T0>
struct index_of_impl<I, T, T0> {
    static const size_t value =
        (std::is_same<T, T0>::value) ? I : static_cast<size_t>(-1);
};
// T0,...,Ts の中で I 番目の型
template <size_t I, typename T0, typename... Ts>
struct at_impl {
    using type =
        typename std::conditional<I == 0,
                                  T0,
                                  typename at_impl<I - 1, Ts...>::type>::type;
};
// T0,...,Ts の中で I 番目の型
template <size_t I, typename T0>
struct at_impl<I, T0> {
    using type = typename std::conditional<I == 0, T0, void>::type;
};
// 型のリスト
template <typename... Ts>
struct type_list_t {
    // 型数
    static constexpr size_t size() { return sizeof...(Ts); };
    // T が最初に現れる位置
    template <typename T>
    static constexpr size_t index_of()
    {
        return index_of_impl<0,
                             remove_cv_reference_t<T>,
                             remove_cv_reference_t<Ts>...>::value;
    }
    // I 番目の型
    template <size_t I>
    using at_t = typename at_impl<I, Ts...>::type;
    // idx 番目の型 T について, f.operator()<T>() を実行する
    template <typename F, typename Index>
    static auto apply(F&& f, Index idx) -> decltype(auto)
    {
        using R = decltype(f.template operator()<at_t<0>>());  // 戻り値の型
        static std::make_index_sequence<size()> seq;  // 整数シーケンス
        return apply<R>(seq, std::forward<F>(f), static_cast<int>(idx));
    }

private:
    // idx 番目の型 T について, f.operator()<T>() を実行する
    template <typename R, typename F, size_t... Is>
    static R apply(std::index_sequence<Is...>, F&& f, int idx)
    {
        using func_t = decltype(&apply<R, F, at_t<0>>);  // 関数ポインタの型
        // 関数ポインタテーブルを生成
        // idx 番目の関数ポインタは apply<R, F, at_t<idx>> である
        static func_t func_table[] = {&apply<R, F, at_t<Is>>...};
        return func_table[idx](std::forward<F>(f));
    }
    // 型 T について, f.operator()<T>() を実行する
    template <typename R, typename F, typename T>
    static R apply(F&& f)
    {
        return f.template operator()<T>();
    }
};

// 加算
struct plus_t {
    template <typename T1, typename T2>
    auto operator()(T1&& lhs, T2&& rhs) const -> decltype(auto)
    {
        return std::forward<T1>(lhs) + std::forward<T2>(rhs);
    }
};
// 減算
struct minus_t {
    template <typename T1, typename T2>
    auto operator()(T1&& lhs, T2&& rhs) const -> decltype(auto)
    {
        return std::forward<T1>(lhs) - std::forward<T2>(rhs);
    }
};
// 乗算
struct multiply_t {
    template <typename T1, typename T2>
    auto operator()(T1&& lhs, T2&& rhs) const -> decltype(auto)
    {
        return std::forward<T1>(lhs) * std::forward<T2>(rhs);
    }
};
// 除算(ゼロ割チェック付き)
struct divide_t {
    template <typename T1, typename T2>
    auto operator()(T1&& lhs, T2&& rhs) const -> decltype(auto)
    {
        if (rhs == static_cast<remove_cv_reference_t<T2>>(0)) {
            throw std::invalid_argument("zero division");
        }
        return std::forward<T1>(lhs) / std::forward<T2>(rhs);
    }
};
// 演算子ファンクタリスト
using op_type_list_t = type_list_t<plus_t, minus_t, multiply_t, divide_t>;
// 演算子を表す Enum
enum class op_type_e : int8_t {
    PLUS = op_type_list_t::index_of<plus_t>(),
    MINUS = op_type_list_t::index_of<minus_t>(),
    MULTIPLY = op_type_list_t::index_of<multiply_t>(),
    DIVIDE = op_type_list_t::index_of<divide_t>()
};
// 演算子文字からEnum
static const std::unordered_map<std::string, op_type_e> g_char_to_enum{
    {"+", op_type_e::PLUS},
    {"-", op_type_e::MINUS},
    {"*", op_type_e::MULTIPLY},
    {"/", op_type_e::DIVIDE}};
// op_type_list_t::apply に渡す計算ファンクタ
struct calculator_t {
    int lhs{0};
    int rhs{0};
    template <typename Op>
    int operator()() const
    {
        return Op()(lhs, rhs);
    }
};

static int calculate(const std::string& op, int lhs, int rhs)
{
    op_type_e op_type = [op]() {
        auto iter = g_char_to_enum.find(op);
        if (iter == g_char_to_enum.end()) {
            throw std::invalid_argument("invalid operator");
        }
        return iter->second;
    }();
    return op_type_list_t::apply(calculator_t{lhs, rhs}, op_type);
}

この Advent Calendar は !(初心者) 向けなのでみなさん余裕だと思いますが, とりあえず下の方から見ていきましょう.

calculate 関数では, 演算子文字から enum 値に変換して, op_type_list_t::apply を呼びます. これで enum 値からテンプレート関数 calculator_t::operator()<T> が呼び出されます. T には op_type_list_t に登録された型が入ってきます.

type_list_t は, テンプレート引数に指定された型のリストを保持します.
コンパイル時関数 size(), index_of<T>(), at_t<I>boost::mpl::vector を使えば簡単に同じ効果を得られますが, 今回はコピペ一発で実行できることを優先して手書きしました.
で, apply が 3段階あります. まず型数からコンパイル時整数シーケンスを得ます. そのシーケンスを元に関数ポインタの配列を生成します. type_list_t<int, long, double> だったら {&apply<int>, &apply<long>, &apply<double>} となるわけです. 指定した enum 値を配列のインデックスとして, 3段目の apply が呼ばれます. 最後にテンプレート引数を指定してファンクタの operator()<T> が呼ばれます. 単純ですね.
テンプレート関数のポインタ配列は StackOverflow の記事 で紹介されています.

剰余算の実装

ここまでだとただめんどくさくなっただけなので, 機能を拡張して switch 除去の良さを確認しましょう.
演算子に剰余を追加します. やることは,
1. ファンクタを作成
2. enum 値を追加
3. 文字列->enum マップに追加
これだけです.

struct divide_t {
         return std::forward<T1>(lhs) / std::forward<T2>(rhs);
     }
 };
+// 剰余算(ゼロ割チェック付き)
+struct modulo_t {
+    template <typename T1, typename T2>
+    auto operator()(T1&& lhs, T2&& rhs) const -> decltype(lhs % rhs)
+    {
+        if (rhs == static_cast<remove_cv_reference_t<T2>>(0)) {
+            throw std::invalid_argument("zero division");
+        }
+        return std::forward<T1>(lhs) % std::forward<T2>(rhs);
+    }
+};
 // 演算子ファンクタリスト
-using op_type_list_t = type_list_t<plus_t, minus_t, multiply_t, divide_t>;
+using op_type_list_t =
+    type_list_t<plus_t, minus_t, multiply_t, divide_t, modulo_t>;
 // 演算子を表す Enum
 enum class op_type_e : int8_t {
     PLUS = op_type_list_t::index_of<plus_t>(),
     MINUS = op_type_list_t::index_of<minus_t>(),
     MULTIPLY = op_type_list_t::index_of<multiply_t>(),
-    DIVIDE = op_type_list_t::index_of<divide_t>()
+    DIVIDE = op_type_list_t::index_of<divide_t>(),
+    MODULO = op_type_list_t::index_of<modulo_t>()
 };
 // 演算子文字からEnum
 static const std::unordered_map<std::string, op_type_e> g_char_to_enum{
     {"+", op_type_e::PLUS},
     {"-", op_type_e::MINUS},
     {"*", op_type_e::MULTIPLY},
-    {"/", op_type_e::DIVIDE}};
+    {"/", op_type_e::DIVIDE},
+    {"%", op_type_e::MODULO}};
 // op_type_list_t::apply に渡す計算ファンクタ
 struct calculator_t {
     int lhs{0};

これだけです. calculate 関数は何も変更する必要がありません.

オペランドの型に double を追加

演算子を追加するだけだとやっぱり switch で十分じゃん, と思われるかもしれないので, double 型も使えるように拡張します.
まず int か double を持てる型 value_t を用意します.

// オペランドの型リスト
using value_type_list_t = type_list_t<int, double>;
enum class value_type_e : int8_t {
    INT = value_type_list_t::index_of<int>(),
    DOUBLE = value_type_list_t::index_of<double>()
};
// intかdoubleを持つ型
using value_t = boost::variant<int, double>;

calculate の引数に value_t を使うように変更します.

-static int calculate(const std::string& func, int lhs, int rhs)
+static value_t calculate(const std::string& func, value_t lhs, value_t rhs)

calculator_t も修正します.

 // op_type_list_t::apply に渡す計算ファンクタ
 struct calculator_t {
-    int lhs{0};
-    int rhs{0};
+    value_t lhs{0};
+    value_t rhs{0};
+
     template <typename Op>
-    int operator()() const
+    value_t operator()() const
+    {
+        switch (lhs.which()) {
+        case value_type_e::INT:
+            return operator()<Op, int>();
+        case value_type_e::DOUBLE:
+            return operator()<Op, double>();
+        }
+    }
+    template <typename Op, typename TL>
+    value_t operator()() const
+    {
+        switch (rhs.which()) {
+        case value_type_e::INT:
+            return operator()<Op, TL, int>();
+        case value_type_e::DOUBLE:
+            return operator()<Op, TL, double>();
+        }
+    }
+    template <typename Op, typename TL, typename TR>
+    value_t operator()() const
     {
-        return Op()(lhs, rhs);
+        return Op()(boost::get<TL>(lhs), boost::get<TR>(rhs));
     }
 };

普通に boost::apply_visitor 使えばいいんじゃ, という疑問は......まぁこの記事はサンプルということで.

このままだと double に演算子 % が定義されていないためコンパイルエラーとなります. SFINAE で処理を分岐させます.

+// nullptr_t を使った enabler_if
+// http://qiita.com/kazatsuyu/items/203584ef4cb8b9e52462
+template <bool pred>
+using enabler_if_t = typename std::enable_if<pred, std::nullptr_t>::type;

 struct modulo_t {
-    template <typename T1, typename T2>
-    auto operator()(T1&& lhs, T2&& rhs) const -> decltype(lhs % rhs)
+    // 両方とも整数ならこちら
+    template <typename T1,
+              typename T2,
+              enabler_if_t<std::is_integral<T1>{} && std::is_integral<T2>{}> =
+                  nullptr>
+    auto operator()(T1&& lhs, T2&& rhs) const -> decltype(auto)
     {
         if (rhs == static_cast<remove_cv_reference_t<T2>>(0)) {
             throw std::invalid_argument("zero division");
         }
         return std::forward<T1>(lhs) % std::forward<T2>(rhs);
     }
+    // どちらかが整数でない場合はこちら
+    template <typename T1,
+              typename T2,
+              enabler_if_t<!(std::is_integral<T1>{} &&
+                             std::is_integral<T2>{})> = nullptr>
+    auto operator()(T1&& lhs, T2&& rhs) const -> decltype(auto)
+    {
+        if (rhs == static_cast<remove_cv_reference_t<T2>>(0)) {
+            throw std::invalid_argument("zero division");
+        }
+        return std::fmod(std::forward<T1>(lhs), std::forward<T2>(rhs));
+    }
 };

これで double でも calculate を使えるようになりました.

Switch 文を除去, 再び

calculator_t の中にまた憎き switch が現れました. こんなのがあったら気軽に他の型を追加できないじゃないか!

滅びよ

 struct calculator_t {
     value_t lhs{0};
     value_t rhs{0};
-
+    // テンプレート引数を部分適用
+    template <typename T0, typename... Ts>
+    struct curried_t {
+        const calculator_t* self{nullptr};
+        template <typename T>
+        value_t operator()() const
+        {
+            assert(self);
+            return self->operator()<T0, Ts..., T>();
+        }
+    };
     template <typename Op>
     value_t operator()() const
     {
-        switch (lhs.which()) {
-        case value_type_e::INT:
-            return operator()<Op, int>();
-        case value_type_e::DOUBLE:
-            return operator()<Op, double>();
-        }
+        return value_type_list_t::apply(curried_t<Op>{this}, lhs.which());
     }
     template <typename Op, typename TL>
     value_t operator()() const
     {
-        switch (rhs.which()) {
-        case value_type_e::INT:
-            return operator()<Op, TL, int>();
-        case value_type_e::DOUBLE:
-            return operator()<Op, TL, double>();
-        }
+        return value_type_list_t::apply(curried_t<Op, TL>{this}, rhs.which());
     }
     template <typename Op, typename TL, typename TR>
     value_t operator()() const

これでよし. これで心置きなく std::string や std::complex などを追加できます.

まとめ

variadic template を使って関数ポインタテーブルを自動生成し, switch 文を除去してみました. type_list_t を使うクライアントコードがやることは,
1. type_list_t に振り分けるファンクタの型を登録
2. type_list_t::apply にロジックのファンクタと型IDを指定
のみです.

参考にしたもの

StackOverflow の記事
nullptr_t を使った enabler_if