LoginSignup
11
6

それって本当にdynamic_castじゃないとダメ? static_castにして高速化

Last updated at Posted at 2023-12-16

はじめに

ダウンキャストやクロスキャストは使っていますでしょうか?

そもそもダウンキャストしないに越したことはないのですが、それでも使っている人もいると思います。
(自分もdynamic_castが大量みたいな現場を経験したことがある)

さて、C++のいろいろあるキャストの中でもdynamic_castは、 実行時型情報 RTTI を使ってキャストするので安全なダウンキャストが可能です。

dynamic_caststatic_castの違いなどは以下の記事でも良くまとまっています。

#include <iostream>

struct Super
{
    virtual ~Super()=default;
};
struct A : Super
{
};
struct B : Super
{
};

int main()
{
    Super* s = new A();
    
    // インスタンスはAなので nullptr が返る
    B* b0 = dynamic_cast<B*>(s);
    if (b0) {
        std::cout << "s is B" << std::endl;
    }
    // 強制的にB*にキャストする nullptr ではない
    B* b1 = static_cast<B*>(s);
    if (b1) {
        std::cout << "s is B ???" << std::endl;
    }
    delete s;
}

しかしstatic_castに比べて速度は遅いのがデメリットになります。

実行時のインスタンスの型が何かしらで保証されていてRTTIに頼る必要がないならばstatic_castを使ったほうがパフォーマンスとしては良いです。

なんでもかんでもdynamic_castにせず、static_castにできる所は変更していくことで
ダウンキャスト(及びクロスキャスト)を高速化してみましょう。

仕様でそうなってるんだからヨシ!

プログラマ的には不安な点が多いですが、プログラム上は基底クラスでやり取りしていても、仕様上絶対に実体インスタンスの型が決まっているなら、dynamic_cast じゃなくてもいいかもしれません。

void hoge(Super* s)
{
    // 仕様的に、この関数が呼ばれる時パラメータの実体は絶対にB型なので、ヨシ! (???)
    B* b = static_cast<B*>(s);
    // Something
}

これも一つの高速化の方法としてはあります。

う~ん、怖いなぁ…

以降では、もう少しだけ安全さを担保しつつstatic_castにしていく方法を考えていきましょう。

基底クラスから明示的に型情報をとれるようにしておく

enum class InstanceKind
{
    Super,
    A,
    B,
};
struct Super
{
    virtual ~Super()=default;
    virtual InstanceKind kind() const
    {
        return InstanceKind::Super;
    }
};
struct A : Super
{
    InstanceKind kind() const override
    {
        return InstanceKind::A;
    }
};
struct B : Super
{
    InstanceKind kind() const override
    {
        return InstanceKind::B;
    }
};

サブクラスの型の種類を事前に決めうちできるようなケースならば、
たとえば上記のようにenum値を持たせておけば、インスタンスの型を基底クラスのポインタからも調べることが可能なので
dynamic_castを使わずとも自分で型チェックができそうです。

template<class ToPtr>
ToPtr my_dynamic_cast(Super* super)
{
    if (!super) {
        return nullptr;    
    }
    if (super->kind() == InstanceKind::Super) {
        if constexpr (std::convertible_to<Super*, ToPtr>) {
            return static_cast<ToPtr>(super);
        }
    } else if (super->kind() == InstanceKind::A) {
        if constexpr (std::convertible_to<A*, ToPtr>) {
            return static_cast<ToPtr>(static_cast<A*>(super));
        }
    } else if (super->kind() == InstanceKind::B) {
        if constexpr (std::convertible_to<B*, ToPtr>) {
            return static_cast<ToPtr>(static_cast<B*>(super));
        }
    }
    return nullptr;
}
B* b = my_dynamic_cast<B*>(s);

こんな感じの自前キャスト関数を用意してみました。

struct A : Super , Interface {};
struct B : Super {};

static_assert(std::convertible_to<A*, Super*>);
static_assert(std::convertible_to<A*, Interface*>);
static_assert(std::convertible_to<B*, Super*>);
static_assert(!std::convertible_to<B*, Interface*>);

std::convertible_toを使ってコンパイル時条件分岐しているのは、クロスキャストへの対応のためです。
キャストできない場合はnullptrを返します。

C++17以前であれば std::is_convertible_vを使いましょう。

enumが増えるたびに調整しないといけないので、メンテナンスが大変というデメリットはありますが、static_cast主体で、dynamic_castのようにキャストが不適切な場合nullptrを返すことができました。

ベンチマークもとってみました。
https://quick-bench.com/q/ec8rYLYN6v7cIYirDOmQ80EgQKs

image.png

最も単純な対処方法ですが、dynamic_castを使うよりかなり高速化されていそうです。

注意点として、サブクラスが仮想継承になる場合はstatic_castでダウンキャストができないのでdynamic_castを使うしかなくなります。

struct A : virtual Super{};
int main()
{
    Super* s = new A();
    A* a = static_cast<A*>(s); // compile error
    if (a) {
        std::cout << "s is A" << std::endl;
    }
    delete s;
}

ラッパークラスを作って元の型情報を持たせておく

個人的に、ちょっと遊びで作ったものですが、いくつかポインタのラッパークラスを作ってみました。
実用的ではないかもしれませんが、テクニックの参考程度にせっかくなので紹介します。

variantを使ったポインタラッパー

variant_ptr

前述の方法のように基底クラスから型情報をとれるようにしなくても
似たようなことができるポインタのラッパークラスを作りました。
こちらもインスタンスの実体になりうるサブクラスが事前に決め打ちできている状態なら使用できます。

struct Super
{
    virtual ~Super()=default;
};
struct Interface
{
    virtual ~Interface()=default;
};
struct A : Super , Interface {};
struct B : Super {};

int main()
{
    variant_ptr<Super, A, B> s{new A()}; // 初期化時だけdynamic_cast同等に重い処理になる
    B* b = variant_dynamic_cast<B*>(s);  // それ以降は何度ダウンキャストしても高速に判定できる
    if (b) {
        std::cout << "s is B" << std::endl;
    }
    Interface* i = variant_dynamic_cast<Interface*>(s);  // クロスキャスト
    if (i) {
        std::cout << "s is Interface" << std::endl;
    }    
    delete s.get();
}

特徴として、初期化時だけ重い、それ以降のダウンキャストは高速に行えます。

初期化時に、元のインスタンスの実体の型を調べるので重い。
型チェックを不要にすれば、初期化時も重くない

実装詳細
variant_ptr.hpp
#pragma once
#include <concepts>
#include <variant>

template<class Base, std::derived_from<Base>... Deriveds>
class variant_ptr
{
public:
    using value_type = std::variant<Base*, Deriveds*...>;
public:
    variant_ptr() noexcept = default;
    template<class Type> requires std::same_as<Type, Base> || (std::same_as<Type, Deriveds> || ...)
        variant_ptr(Type * ptr) noexcept
    {
        *this = ptr;
    }
    variant_ptr(const variant_ptr& other) noexcept :
        m_ptr(other.m_ptr)
    {}
    variant_ptr(std::nullptr_t) noexcept :
        m_ptr()
    {}
    Base* get() const noexcept
    {
        return cast_to<Base*>();
    }
    Base* operator ->() const noexcept
    {
        return get();
    }
    Base& operator *() const noexcept
    {
        return *get();
    }
    template<class ToPtr> requires std::is_pointer_v<ToPtr> && (std::convertible_to<Base*, ToPtr> || (std::convertible_to<Deriveds*, ToPtr> || ...))
        ToPtr cast_to() const noexcept
    {
        return std::visit([]<class T>(T * ptr)->ToPtr {
            if constexpr (std::convertible_to<T*, ToPtr>) {
                return static_cast<ToPtr>(ptr);
            } else {
                return nullptr;
            }
        }, m_ptr);
    }
    void reset()
    {
        m_ptr = value_type{};
    }
    explicit operator bool() const noexcept
    {
        return get() != nullptr;
    }
    template<class Type> requires std::same_as<Type, Base> || (std::same_as<Type, Deriveds> || ...)
    variant_ptr& operator=(Type* ptr) noexcept
    {
        if (ptr == nullptr) {
            m_ptr = value_type{};
            return *this;
        }
        const auto& rtti = typeid(*ptr);
        if (rtti == typeid(Type)) [[likely]] {
            m_ptr = ptr;
            return *this;
            }
        Base* base = static_cast<Base*>(ptr);
        if (rtti == typeid(Base)) {
            m_ptr = base;
            return *this;
        }
        auto derived_checks = [&]<class T>(T*) {
            if (rtti == typeid(T)) {
                if constexpr (std::convertible_to<Base*, T*>) {
                    m_ptr = static_cast<T*>(base);
                    return true;
                } else {
                    m_ptr = dynamic_cast<T*>(base);
                    return true;
                }
            }
            return false;
        };
        if ((derived_checks(static_cast<Deriveds*>(nullptr)) || ...)) {
            return *this;
        }
        m_ptr = ptr;
        return *this;
    }
    variant_ptr& operator=(const variant_ptr& other) noexcept
    {
        m_ptr = other.m_ptr;
        return *this;
    }
private:
    value_type m_ptr;
};
template<class ToPtr>
struct variant_dynamic_cast_impl
{
    template<class Base, std::derived_from<Base>... Deriveds>
        requires std::is_pointer_v<ToPtr> && (std::convertible_to<Base*, ToPtr> || (std::convertible_to<Deriveds*, ToPtr> || ...))
    ToPtr operator()(const variant_ptr<Base, Deriveds...>& ptr) const noexcept
    {
        return ptr.template cast_to<ToPtr>();
    }
};
template<class ToPtr>
inline constexpr auto variant_dynamic_cast = variant_dynamic_cast_impl<ToPtr>{};

あらかじめサブクラスの型を明示して、各ポインタのvariantを作っておくことで
型情報を残すようにしています。

std::variant<Base*, Deriveds*...>

初期化時だけ安全重視でtypeidを使ってRTTIチェックしています。

    template<class Type> requires std::same_as<Type, Base> || (std::same_as<Type, Deriveds> || ...)
    variant_ptr& operator=(Type* ptr) noexcept
    {
        if (ptr == nullptr) {
            m_ptr = value_type{};
            return *this;
        }
        const auto& rtti = typeid(*ptr);
        if (rtti == typeid(Type)) [[likely]] {
            m_ptr = ptr;
            return *this;
            }
        Base* base = static_cast<Base*>(ptr);
        if (rtti == typeid(Base)) {
            m_ptr = base;
            return *this;
        }
        auto derived_checks = [&]<class T>(T*) {
            if (rtti == typeid(T)) {
                if constexpr (std::convertible_to<Base*, T*>) {
                    m_ptr = static_cast<T*>(base);
                    return true;
                } else {
                    m_ptr = dynamic_cast<T*>(base);
                    return true;
                }
            }
            return false;
        };
        if ((derived_checks(static_cast<Deriveds*>(nullptr)) || ...)) {
            return *this;
        }
        m_ptr = ptr;
        return *this;
    }

これは以下のようなコードを書かれた場合にも元の型を調べるためです。

    Super* s = new B();
    variant_ptr<Super, A, B> vs{s}; // Super*が引数で渡されてるけど実体はB
    B* b = variant_dynamic_cast<B*>(vs);

上記のようなケースを考慮せず、コンストラクタ(代入)の引数の型を基準にしてよければ、以下のように修正するのが良い。

    template<class Type> requires std::same_as<Type, Base> || (std::same_as<Type, Deriveds> || ...)
    variant_ptr& operator=(Type* ptr) noexcept
    {
        m_ptr = ptr;
        return *this;
    }

visitを使って元の型からキャストを試すので
安全にキャストが可能です。
こちらも適切なキャストでない場合は nullptr になります。

        return std::visit([]<class T>(T * ptr)->ToPtr {
            if constexpr (std::convertible_to<T*, ToPtr>) {
                return static_cast<ToPtr>(ptr);
            } else {
                return nullptr;
            }
        }, m_ptr);

型消去とキャスト演算子のオーバーライドを使ったポインタラッパー

fixed_ptr

こちらは、インスタンスの実体の型は事前に決め打ちできずともキャスト先の型が事前にわかるようなケースで使えるラッパーです。
内部でshared_ptrを使っていますので、スマートポインタのラッパーになります。

struct Super
{
    virtual ~Super()=default;
};
struct Interface
{
    virtual ~Interface()=default;
};
struct A : Super , Interface {};
struct B : Super {};

int main()
{
    fixed_ptr<Super, Interface, B> s{new A()};
    B* b = fixed_dynamic_cast<B*>(s);  // ダウンキャスト
    if (b) {
        std::cout << "s is B" << std::endl;
    }

    Interface* i = fixed_dynamic_cast<Interface*>(s);  // クロスキャスト
    if (i) {
        std::cout << "s is Interface" << std::endl;
    }    
}

こちらはvariant_ptrと違い、初期化時の実体チェックが難しいので
コンストラクタに渡された引数の型をベースに判定しており
実体と誤なる型で引数にした場合は正しく判定できない事に注意

    Super* base = new B();
    fixed_ptr<Super, Interface, B> s{base}; // 引数の型がSuper*になってる
    B* b = fixed_dynamic_cast<B*>(s);  // nullptr
実装詳細
fixed_ptr.hpp
#pragma once
#include <concepts>
#include <variant>
#include <memory>

template<class Base, class... Types>
class fixed_ptr
{
private:
    template<class T>
    struct icaster
    {
        virtual ~icaster() = default;
        virtual operator T*() const = 0;
    };
    struct base_type : icaster<Types>...
    {
        using icaster<Types>::operator Types*...;
        virtual ~base_type() = default;
        virtual Base* get() const = 0;
    };
    template<class T, class Head, class... Tail>
    struct holder : holder<T, Tail...>
    {
        using holder<T, Tail...>::holder;
        using holder<T, Tail...>::ptr;
        using holder<T, Tail...>::operator Tail*...;
        operator Head* () const override
        {
            if constexpr (std::convertible_to<T*, Head*>) {
                return static_cast<Head*>(ptr);
            } else {
                return nullptr;
            }
        }
    };
    template<class T, class Last>
    struct holder<T, Last> : base_type
    {
        T* ptr;
        holder(T* _ptr):
            ptr(_ptr)
        {}
        ~holder()
        {
            delete ptr;
        }
        Base* get() const override
        {
            return ptr;
        }
        operator Last* () const override
        {
            if constexpr (std::convertible_to<T*, Last*>) {
                return static_cast<Last*>(ptr);
            } else {
                return nullptr;
            }
        }
    };
public:
    fixed_ptr() = default;
    template<class Type>
    fixed_ptr(Type* ptr) requires std::derived_from<Type, Base> :
        m_ptr(std::make_shared<holder<Type, Types...>>(ptr))
    {}
    fixed_ptr(std::nullptr_t) :
        m_ptr(nullptr)
    {}
    Base* get() const
    {
        if (!m_ptr) {
            return nullptr;
        }
        return m_ptr->get();
    }
    Base* operator ->() const
    {
        return get();
    }
    Base& operator *() const
    {
        return *get();
    }    
    template<class ToPtr>
        requires (std::same_as<std::decay_t<ToPtr>, Types*> || ...)
    ToPtr cast_to() const
    {
        if (!m_ptr) {
            return nullptr;
        }
        return static_cast<ToPtr>(*m_ptr);
    }
    void reset()
    {
        m_ptr = nullptr;
    }
    explicit operator bool() const
    {
        return get() != nullptr;
    }
    template<class T, class... Args>
    static fixed_ptr make_fixed(Args&&... args)
    {
        return fixed_ptr(new T(std::forward<Args>(args)...));
    }
private:
    std::shared_ptr<base_type> m_ptr;
};
template<class ToPtr>
struct fixed_dynamic_cast_impl
{
    template<class Base, class... Types>
        requires (std::same_as<std::decay_t<ToPtr>, Types*> || ...)
    ToPtr operator()(const fixed_ptr<Base, Types...>& ptr) const
    {
        return ptr.template cast_to<ToPtr>();
    }
};
template<class ToPtr>
inline constexpr auto fixed_dynamic_cast = fixed_dynamic_cast_impl<ToPtr>{};

型消去で元のポインタ型の情報を保持しつつ
キャスト先になりうる型へのキャスト演算子を全部overrideすることで
static_castベースでのダウンキャストを実現します。

    template<class T, class Head, class... Tail>
    struct holder : holder<T, Tail...>
    {
        using holder<T, Tail...>::holder;
        using holder<T, Tail...>::ptr;
        using holder<T, Tail...>::operator Tail*...;
        operator Head* () const override
        {
            if constexpr (std::convertible_to<T*, Head*>) {
                return static_cast<Head*>(ptr);
            } else {
                return nullptr;
            }
        }
    };
    
    template<class T, class Last>
    struct holder<T, Last> : base_type
    {
        T* ptr;
        holder(T* _ptr):
            ptr(_ptr)
        {}
        ~holder()
        {
            delete ptr;
        }
        Base* get() const override
        {
            return ptr;
        }
        operator Last* () const override
        {
            if constexpr (std::convertible_to<T*, Last*>) {
                return static_cast<Last*>(ptr);
            } else {
                return nullptr;
            }
        }
    };

関連

本題とはずれますが、気を付けたいところ。

ダウンキャストをキャッシュする

複数回、同じ型にダウンキャストするのであれば、きちんとキャッシュしましょう。
dynamic_castしか使えなかったとしても、こういった細かい気遣いが効いてくるかもしれません。

int main()
{
    Super* s = new A();
    if (dynamic_cast<A*>(s)) {
        std::cout << "s is A" << std::endl;
    }
    if (dynamic_cast<A*>(s)) {
        std::cout << "s is A" << std::endl;
    }
}
int main()
{
    Super* s = new A();
    A* a = dynamic_cast<A*>(s);
    if (a) {
        std::cout << "s is A" << std::endl;
    }
    if (a) {
        std::cout << "s is A" << std::endl;
    }
}

まとめ

https://quick-bench.com/q/Qo8jA5Ula132ehRXZS049m4uwlY
image.png

ラッパークラスを作ったりとかいろいろ試しながら
dynamic_castをなんとかstatic_castに変換しつつも、ある程度安全なキャストを実現する方法をいろいろ考えてみました。

dynamic_castは意外と重いので状況次第ではstatic_castに変換することも検討してみてはどうでしょうか?

そもそも、ダウンキャストやクロスキャストが無いのが好ましいと言われたら…そうなんだろうが

11
6
3

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
11
6