LoginSignup
4
3

C++でパターンマッチ

Last updated at Posted at 2020-05-12

C++でパターンマッチ

C++でstd::variantに対するパターンマッチをする際にいちいちVisitorクラスを定義するのが嫌すぎたので自前でパターンマッチ(的なもの)を実装しました。既存の実装はマクロを使っているものがほとんどだったので今回はマクロを使わずテンプレートメタプログラミングのみで実装してみました。

ソースコードはここにあります。

実装

完成品とその利用例です

#include <type_traits>
#include <variant>
#include <optional>
#include <iostream>

namespace satch
{
    template <typename T, typename U, typename = std::void_t<>>
    struct is_comparable : std::false_type
    {
    };

    template <typename T, typename U>
    struct is_comparable<
        T, U, std::void_t<decltype((std::declval<T>() == std::declval<U>()))>>
        : std::true_type
    {
    };

    class _TypeBase{};

    class _Type{};

    template <typename T>
    class Type:public _TypeBase
    {
        public: 
        using match_type = T;
        using __type_id = _Type;
        Type() = default;
    };

    class _Value{};

    template <typename T>
    class Value
    {
        public: 
        using match_type = T;
        using __type_id = _Value;
        private:
        T pattern;

        public:
        Value()=delete;
        Value(T value) : pattern(value){}
        const T& get_pattern() const { return pattern; }
    };

    class Default
    {
        public:
        Default() = default;
    };

    template <typename Type1, typename... Types>
    constexpr bool has_default_case()
    {
        // end of checking
        if constexpr (std::is_same<Default, Type1>::value){return true;}
        else 
        {
            //continue checking
            constexpr bool result = has_default_case<Types...>();
            return result;
        }
        return false;
    }

    template <typename VariantType>
    class Match
    {
        private:
        VariantType& variant;

        template <typename Pattern, typename Function, typename... Rest>
        auto get_result_type()
        {
            if constexpr (std::is_same<Pattern,Default>::value)
            {
                using result_type = decltype(std::declval<Function>()(std::declval<VariantType>()));
                return result_type();
            }
          
            else
            {
                using match_type = typename Pattern::match_type;
                using result_type = decltype(std::declval<Function>()(std::declval<match_type>()));
                return result_type();
            }
        }

        template <int Index, typename ResultType, typename CaseType,
                typename FunctionType, typename... Rest>
        ResultType match(CaseType& case_obj,
                FunctionType& function, Rest&... rest)
        {
            if constexpr (std::is_same<CaseType,Default>::value)
            {
                static_assert(sizeof...(Rest) == 0,"Default case need to be passed at the end of arguments.");
                return function(variant);
            }
            else
            {
                using match_type = typename std::remove_reference<CaseType>::type::match_type;

                if (std::holds_alternative<match_type>(variant))
                {
                    //if pattern has value
                    if constexpr(std::is_same<typename CaseType::__type_id,_Value>::value)
                    {              
                        static_assert(is_comparable<match_type, match_type>::value,"cannnot compare pattern and value due to its type");
                                
                        if (std::get<match_type>(variant) == case_obj.get_pattern())
                        {
                            return function(std::get<match_type>(variant));
                        }
                    }
                    else
                    {
                        return function(std::get<match_type>(variant));
                    }
                }
                if constexpr (Index == 0)
                {
                    if constexpr (has_default_case<CaseType,Rest...>())
                    {}//OK.
                    else
                    {
                        constexpr bool not_instantiated = []{return false;}();
                        static_assert(not_instantiated,"The number of patterns must be greater than the number of the variant's tags unless 'Default' case is passed.");
                    }
                }

                if constexpr (sizeof...(Rest) != 0)
                {
                    return match<Index + 1,ResultType>(rest...);
                }
            }
        }

        public:
        Match(VariantType& variant) : variant(variant) {}

        template <typename... Args>
        auto operator()(Args&&... args)
        {
            using result_type = decltype(get_result_type<Args...>());
            return match<0,result_type>( args...);        
        }
    };
} 


int main()
{
    auto variant = std::variant<int, std::string, double, float>(10);
    std::string input;
    std::cout << "please input text: ";
    std::cin >> input;
    variant = input;

    auto result = satch::Match{variant}(
        satch::Type<std::string>(), [](auto&& str) 
            {
                std::cout << "variant contains string value: " << str << std::endl;
                return 1; 
            },
        satch::Value<int>(10), [](auto&& value)
            {
                std::cout << "variant contains int value 10" << val << std::endl;
                return 2;
            },
        satch::Default(),[](auto&& variant) 
            {
                std::cout << "variant.index() = " << variant.index() << std::endl;
                return 3;
            }
        );
    
    std::cout << "matching returned: " << result << std::endl;
    
    return 0;
}

こんな感じで、Matchクラスがコンストラクタでvariant型の値を受け取り、operator()でパターンと関数を受け取っています。ここで重要なのは引数の順番で、Matchクラスはoperator()で「パターン、関数、パターン、関数...」の順番で引数を取ることを前提に実装されています。また、パターンにはType、Value、Defaultの三種類があり、それぞれ型の一致、値の一致、何も一致しなかった場合を表します。

では、実装の流れを追ってみましょう。

template <typename T, typename U, typename = std::void_t<>>
struct is_comparable : std::false_type
{
};

template <typename T, typename U>
struct is_comparable<
T, U, std::void_t<decltype((std::declval<T>() == std::declval<U>()))>>
    : std::true_type
{
};

まずは冒頭部のis_comparableという構造体について。

この構造体はstd::void_tというメタ関数を利用してoperator()==が適用できるかを調べることで二つの型が比較可能か否かを判別しています。これはパターンにValueクラスのインスタンスが渡された際、そもそも値が比較可能かを調べ、比較不能だった際にわかりやすいコンパイルエラーを出すために定義されています。実装側でチェックしなくてもコンパイルエラーはきちんと出されますが、C++のテンプレートのエラーは非常にわかりづらいので自前でこのようなチェックを挟んでいます。

続くhas_default_case関数は、引数リストの中にDefaultケースが存在するかを調べる関数です。

続いて、Matchクラスのメンバ関数の実装を見ていきます。

template <typename Pattern, typename Function, typename... Rest>
auto get_result_type()
{
		if constexpr (std::is_same<Pattern,Default>::value)
		{
				using result_type = decltype(std::declval<Function>()(std::declval<VariantType>()));
				return result_type();
    }
		else
		{
				using match_type = typename Pattern::match_type;
				using result_type = decltype(std::declval<Function>()(std::declval<match_type>()));
				return result_type();
    }
}

最初のメンバ関数get_result_typeは、渡された引数から関数を適用した際の戻り値を取得する関数です。
パターンがDefaultか否かで処理を分けているのは、Defaultの時のみ渡された関数に引数としてvariantオブジェクトをそのまま渡すからです。また、戻り値の型の取得には、declvalという「ある型の値だが実際に評価されることはない」ものを返す関数を利用しています。型だけ指定されている架空の値を返す関数といったところでしょうか。
これを利用することで「型だけ指定されている架空の関数オブジェクトに型だけ指定されている架空の値を渡した時に返り値の型はどうなるか」を取得しています。ややこしいですね......
C++でのメタプログラミングは複雑になりがちなので細かく関数に切り分けることが大切になってきます。

次のメンバ関数matchは、コードが長い割にやっていることは単調な場合分けなので解説は省略します。

コンストラクタについてもvariantの参照を受け取っているだけなので特に解説する箇所はありません。

最後のメンバ関数operator()もコードとしては短いです。


template <typename... Args>
auto operator()(Args&&... args)
{
    using result_type = decltype(get_result_type<Args...>());
		return match<0,result_type>( args...);        
}

この関数では前述のget_result_typeで戻り値の型を取得し、match関数に引数とともに渡しています。match関数のテンプレートパラメーターの一つ目に0を指定しているのは、match関数内部で可変長引数に対応するため再起呼び出しを行なっており、何回めの再起呼び出しかを把握することでパターンの記述漏れ等をチェックしているため最初の呼び出しで0を指定する必要があるからです。

解説は以上となります。大分複雑なコードですが、順を追って見ていけば何となく理解できてくるかもしれません。

4
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
3