LoginSignup
19
16

More than 5 years have passed since last update.

C++ で基底クラスの配列を適切に使う方法

Posted at

この記事では、C++ で派生クラスを基底クラスの配列に適切に保存する方法についてまとめます。
なお、全ての例は g++ 7.2.0 でオプションなしのコンパイルをして実行しました。
ただし、途中に出てくるPtrVectorクラスは Visual Studio 2017 上でも動作確認しています。

もし訂正や加筆などありましたらコメントください。

普通の配列とその問題点

まず、次のようなソースコードを考えます。(上手くいかない例)

#include <iostream>

class Base {
public:
    virtual void who() {
    std::cout << "Base" << std::endl;
    }
    virtual ~Base() = default;
};

class Derived1 : public Base {
public:
    void who() {
    std::cout << "Derived1" << std::endl;
    }
};

class Derived2 : public Base {
public:
    void who() {
    std::cout << "Derived2" << std::endl;
    }
};

void test_fail() {
    Base arr[2];
    Derived1 d1;
    Derived2 d2;
    arr[0] = d1;
    arr[1] = d2;
    d1.who();
    d2.who();
    arr[0].who();
    arr[1].who();
}

このソースコードでは単純に基底クラスBaseとその派生クラスDerived1Derived2を作り、test_fail関数内ではそれらをBaseの配列へ保存しています。そして、それらのデータがどちらの派生クラスなのかwho関数で調べています。
test_fail関数の実行結果は

Derived1
Derived2
Base
Base

となります。
このとき、代入文

    arr[0] = d1;

の左辺はBaseのオブジェクトであり、そこへDerived1のオブジェクトをコピーする際はただBaseの部分だけがコピーされます。その結果、arrへコピーした後のオブジェクトでwho関数を呼び出してもBaseと表示されてしまいます。

この問題を解決するために、3通りの解決法を紹介します。原理的には同じですが、後者の方がより最新のC++っぽいコードになります。

解決法1: ポインタの配列

まず最初に継承の基本に戻って以下の関数の挙動を確認します。

void test_derived() {
    Derived1 d1;
    Base obj = d1;
    Base* ptr = &d1;
    Base& ref = d1;
    obj.who();
    ptr->who();
    ref.who();
}

これを実行すると、

Base
Derived1
Derived1

と表示されます。つまり、ポインタと参照なら問題なく派生クラスのwhoを呼んでくれます。
そこで、ここではポインタの配列を使ってみます。

void test_ptrarray() {
    Base * arr[2];
    Derived1 d1;
    Derived2 d2;
    arr[0] = &d1;
    arr[1] = &d2;
    d1.who();
    d2.who();
    arr[0]->who();
    arr[1]->who();
}

これを実行すると、

Derived1
Derived2
Derived1
Derived2

のようにうまくいきました。arrにポインタを格納しているため、ポインタを通して派生クラスのwho関数を呼び出すことができています。

解決法2: スマートポインタとコンテナを使う (C++11)

続いて、前節のポインタの配列をスマートポインタのコンテナへ変更してみます。
スマートポインタ自体については
https://qiita.com/hmito/items/db3b14917120b285112f
などをご覧ください。

#include <memory>
#include <vector>

void test_sptrarray() {
    std::vector<std::shared_ptr<Base> > arr(2);
    auto d1 = std::make_shared<Derived1>();
    auto d2 = std::make_shared<Derived2>();
    arr[0] = d1;
    arr[1] = d2;
    d1->who();
    d2->who();
    arr[0]->who();
    arr[1]->who();
}

std::shared_ptrは派生クラスから基底クラスへ代入可能なように作られているため、普通のポインタと同じようにstd::shared_ptr<Derived1>std::shared_ptr<Base>へ代入できます。さらに、std::vectorstd::shared_ptrが自動でオブジェクトを破棄してくれるため、動的なメモリ確保でもメモリリークの心配がありません。

解決法3: コンテナを作る

最後に、前節を基にした自作のコンテナを紹介します。

  • 基本的にstd::vectorと同じような使い方ができるよう考えた API になっています。
  • 比較的新しい機能を色々使って書いているため、コンパイラが新しくないとコンパイルできないかもしれないことをご了承ください。
  • コメントは doxygen 用のスタイルで書いています。
  • 自由にコピーしたり書き換えたりして使って構いませんが、トラブルが発生しても責任は負いません。
PtrVector.h
/*! \file PtrVector.h
\author Kenta Kabashima
\date 2017/8/1

For vector of pointers
*/
#pragma once

#include <type_traits>
#include <memory>
#include <vector>

//! a vector for classes derived from one base classes
/*
a vector for classes derived from one base classes
Example:
\code{.cpp}
class Base {};

class Derived1 : public Base {};

class Derived2 : public Base {};

void test()
{
    PtrVector<Base> vec;
    vec.push_back(Derived1());
    vec.push_back(Derived2());
}
\endcode
*/
template<typename BaseType, typename Container = std::vector<std::shared_ptr<BaseType> > >
class PtrVector {
public:
    //! type of container
    using container_type = Container;
    //! type of size
    using size_type = typename Container::size_type;
    //! type of difference
    using difference_type = typename Container::difference_type;
    //! type of value
    using value_type = BaseType;
    //! type of value in container
    using inner_value_type = typename Container::value_type;

private:
    //! data
    Container _data;

public:
    //! default constructor (do nothing)
    PtrVector() noexcept : _data() {}

    //! add a data (STL-like function)
    /*!
    push a entry to the end of the vector on the condition that it is a instance of a class derived from BaseType.
    */
    template <typename Type>
    inline auto push_back(Type&& data)
    -> typename std::enable_if_t<std::is_base_of<BaseType, std::remove_const_t<std::remove_reference_t<Type> > >::value >
    {
        _data.push_back(std::make_shared<std::remove_const_t<std::remove_reference_t<Type> > >(std::forward<Type>(data)));
    }

    //! add a data with constructor
    /*!
    construct and push a entry to the end of the vector on the condition that it is a instance of a class derived from BaseType.
    */
    template <typename Type, typename... Args>
    inline auto emplace_back(Args&&... args)
    -> typename std::enable_if_t<std::is_base_of<BaseType, Type>::value >
    {
        _data.push_back(std::make_shared<Type>(std::forward<Args>(args)...));
    }

    //! i-th data
    BaseType& operator[](size_type i) { return *_data[i]; }
    //! i-th data
    const BaseType& operator[](size_type i) const { return *_data[i]; }

    //! reserve memory
    void reserve(size_type size) { _data.reserve(size); }

    //! clear memory
    void clear() { _data.clear(); }

    //! get size
    size_type size() const { return _data.size(); }
};

これをヘッダとして読み込んで、

#include "PtrVector.h"

void test_ptrvector() {
    PtrVector<Base> vec;
    vec.push_back(Derived1());
    vec.emplace_back<Derived2>();
    vec[0].who();
    vec[1].who();
}

のようなtest_ptrvector関数を実行すると、

Derived1
Derived2

のように望み通りの結果が得られます。

ソースコード

最後に、テスト用に作ったソースコード全体を載せておきます。

test.cpp
#include <iostream>

class Base {
public:
    virtual void who() {
    std::cout << "Base" << std::endl;
    }
    virtual ~Base() = default;
};

class Derived1 : public Base {
public:
    void who() {
    std::cout << "Derived1" << std::endl;
    }
};

class Derived2 : public Base {
public:
    void who() {
    std::cout << "Derived2" << std::endl;
    }
};

void test_fail() {
    Base arr[2];
    Derived1 d1;
    Derived2 d2;
    arr[0] = d1;
    arr[1] = d2;
    d1.who();
    d2.who();
    arr[0].who();
    arr[1].who();
}

void test_derived() {
    Derived1 d1;
    Base obj = d1;
    Base* ptr = &d1;
    Base& ref = d1;
    obj.who();
    ptr->who();
    ref.who();
}

void test_ptrarray() {
    Base * arr[2];
    Derived1 d1;
    Derived2 d2;
    arr[0] = &d1;
    arr[1] = &d2;
    d1.who();
    d2.who();
    arr[0]->who();
    arr[1]->who();
}

#include <memory>
#include <vector>

void test_sptrarray() {
    std::vector<std::shared_ptr<Base> > arr(2);
    auto d1 = std::make_shared<Derived1>();
    auto d2 = std::make_shared<Derived2>();
    arr[0] = d1;
    arr[1] = d2;
    d1->who();
    d2->who();
    arr[0]->who();
    arr[1]->who();
}

#include "PtrVector.h"

void test_ptrvector() {
    PtrVector<Base> vec;
    vec.push_back(Derived1());
    vec.emplace_back<Derived2>();
    vec[0].who();
    vec[1].who();
}

int main() {
    //test_fail();
    //test_derived();
    //test_ptrarray();
    //test_sptrarray();
    test_ptrvector();
}
19
16
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
19
16