私が普段よくお世話になっている行列計算のライブラリ Eigen では式テンプレート(Expression template)が高速化に利用されているそうです。しかし、実際のところどれくらい効果があるのか数字で見たことはなかったため、今回自分で実験してみました。
本文中のソースコードは今回実験するのに必要最低限の機能だけ実装した適当なものですので、実用に使いたい人は Eigen を使いましょう。(参考:以前書いた Eigen の密行列関連の記事)
式テンプレートについて
ここでは、簡単のためベクトルの足し算で式テンプレートを簡単に紹介します。
そこで、先に式テンプレートを使わない簡単な実装とその問題点を紹介し、そのあと式テンプレートを用いた実装を紹介し、2つの実装を比較していきます。
簡単な実装とその問題点
まず、次のクラスを作ってみました。
#pragma once
#include <vector>
namespace Temp {
class Vector {
private:
std::vector<double> _data;
public:
Vector() : _data() {}
Vector(std::size_t size) : _data(size) {}
Vector(const Vector& right) : _data(right._data) {}
Vector& operator=(const Vector& right) {
_data = right._data;
return *this;
}
double& operator[](std::size_t i) { return _data[i]; }
double operator[](std::size_t i) const { return _data[i]; }
Vector& operator+=(const Vector& right) {
std::size_t size = _data.size();
for (std::size_t i = 0; i < size; ++i) {
_data[i] += right._data[i];
}
return *this;
}
};
Vector operator+(const Vector& left, const Vector& right) {
return Vector(left) += right;
}
}
実験に必要な機能のみのシンプルなクラスです。operator+
の実装には More Effective C++ を参考にしました。
名前空間 Temp は後述のようにこの Vector クラスで operator+
を用いると一時オブジェクトが生成されることからつけています。(後で違う Vector クラスを作成するので名前空間を分けました。)
この Vector クラスのオブジェクト vec1, vec2, vec3 を使って
vec3 = vec1 + vec2;
を実行すると、
-
operator+(vec1, vec2)
が呼び出される。 -
Vector(vec1)
で vec1 のコピーを行う(これを temp_vec とする)。 - temp_vec に
operator+=
でvec2
が足される。 -
operator+(vec1, vec2)
の戻り値として temp_vec が出力される。 - temp_vec を vec3 に代入する。
のように処理が進みます。このとき、一時オブジェクト temp_vec に一旦 vec1 + vec2
の結果を保存してそれを vec3 にコピーするのですが、最初から vec3 に vec1 + vec2
の結果を保存した方が速くなりそうです。
ちなみに、この問題を比較的簡単に解決する方法として、
vec3 = vec1;
vec3 += vec2;
の順に実行するというのがあり、簡単に自作クラスを作って使いたいのであればこれが良いと思います。その場合、思い切って operator+
は実装せず必ず operator+=
を使うようにするというのも1つの手かもしれません。
式テンプレートによる実装例
一時オブジェクトの問題を式テンプレートで解消する例として、次のクラスを作ってみました。
#pragma once
#include <vector>
namespace Expr {
template <typename Derived>
class VectorBase {
public:
Derived& derived() { return *static_cast<Derived* const>(this); }
const Derived& derived() const { return *static_cast<const Derived* const>(this); }
double operator[](std::size_t i) const { return derived().operator[](i); }
std::size_t size() const { return derived().size(); }
};
class Vector : public VectorBase<Vector> {
private:
std::vector<double> _data;
public:
Vector() : _data() {}
Vector(std::size_t size) : _data(size) {}
double& operator[](std::size_t i) { return _data[i]; }
double operator[](std::size_t i) const { return _data[i]; }
std::size_t size() const { return _data.size(); }
template <typename RightType>
Vector(const VectorBase<RightType>& right)
: _data()
{
const RightType& right_derived = right.derived();
std::size_t size = right_derived.size();
_data.resize(size);
for (std::size_t i = 0; i < size; ++i) {
_data[i] = right_derived[i];
}
}
template <typename RightType>
Vector& operator=(const VectorBase<RightType>& right) {
const RightType& right_derived = right.derived();
std::size_t size = right_derived.size();
_data.resize(size);
for (std::size_t i = 0; i < size; ++i) {
_data[i] = right_derived[i];
}
return *this;
}
};
template <typename LeftType, typename RightType>
class Add : public VectorBase<Add<LeftType, RightType> > {
private:
const LeftType& _left;
const RightType& _right;
public:
Add(const LeftType& left, const RightType& right)
: _left(left), _right(right)
{}
std::size_t size() const { return _left.size(); }
double operator[](std::size_t i) const { return _left[i] + _right[i]; }
};
template <typename LeftType, typename RightType>
Add<LeftType, RightType> operator+(const VectorBase<LeftType>& left, const VectorBase<RightType>& right) {
return Add<LeftType, RightType>(left.derived(), right.derived());
}
}
Eigen を参考に実装しました。式テンプレートを使ったのに加え、あとで3つのベクトルを足し算する例を紹介したいため CRTP (Curriously Recursing Template Pattern) による基底クラス VectorBase を作り、VectorBase に対して足し算を定義するという実装にしました。
この Vector クラスのオブジェクト vec1, vec2, vec3 を使って vec3 = vec1 + vec2
を実行すると、
-
vec1 + vec2
を表現するAdd<Vector, Vector>(vec1, vec2)
というオブジェクトを生成。(※具体的な計算はしていないことに注意!) -
Add<Vector, Vector>(vec1, vec2)
を vec3 に代入する。このとき、Add<Vector, Vector>(vec1, vec2)
のoperator[](i)
が呼ばれ、そこでvec1[i] + vec2[i]
を要素ごとに計算する。
のように計算が行われます。このとき、Add<Vector, Vector>(vec1, vec2)
という一時オブジェクトが生成されますが、この一時オブジェクトは vec1 と vec2 への参照を持っているだけのため、比較的簡単に生成・破棄できます。また、vec1 + vec2
の結果を直接 vec3 に代入できています。
3つのベクトルを足す場合の比較
ここまで、2つのベクトルの足し算で一時オブジェクトが生成されるかどうかの比較をしていましたが、ここでは3つのベクトルを足す場合の比較を通して式テンプレートの良さを紹介します。
vec4 = vec1 + vec2 + vec3
を実行すると、Temp::Vector
では
Vector temp_vec1(vec1);
std::size_t size1 = temp_vec1.size();
for (std::size_t i = 0; i < size1; ++i) {
temp_vec1[i] += vec2[i];
}
Vector temp_vec2(temp_vec1);
std::size_t size2 = temp_vec2.size();
for (std::size_t i = 0; i < size2; ++i) {
temp_vec2[i] += vec3[i];
}
vec4 = temp_vec2;
のような動作をします。2つも一時オブジェクトが生成されるのでベクトルのサイズによってはなかなかのオーバーヘッドになるかもしれません。しかし、Expr::Vector
では、
std::size_t size = vec1.size();
for (std::size_t i = 0; i < size; ++i) {
vec4[i] = vec1[i] + vec2[i] + vec3[i];
}
と動作します。これを実現するために vec1 + vec2
を表す Add<Vector, Vector>(vec1, vec2)
(これを temp_vec1 とする)と、そこに vec3 を加えた Add<Add<Vector, Vector>, Vector>(temp_vec1, vec3)
が生成されますが、これらは足すベクトルの参照を持っているだけなため生成・破棄のオーバーヘッドは比較的小さいと考えられます。
このように一括で3つのベクトルを1つのベクトルへ足すことができるというのは式テンプレートの便利な点の1つと言えます。なお、Eigen では式テンプレートを用いてより速い計算法で計算ができるようにプログラムを書いているそうです(Eigen のドキュメンテーションを参照)。
実験
ここで、一時オブジェクト上に足し算の結果を保存する Temp::Vector
と式テンプレートを用いた Expr::Vector
で vec3 = vec1 + vec2
の時間と vec4 = vec1 + vec2 + vec3
の時間を競争してみました。
条件
Intel Core i7-7700HQ の入ったノートパソコンを使用し、Windows 10 上の Visual Studio 2017 Community で Release モードで後述のソースコードをコンパイルし、実行しました。
結果
上のグラフは100回計算を行った平均と標準誤差をグラフにしたものです。凡例の Temp と Expr は Temp::Vector
と Expr::Vector
を表し、そのあとの数字は足したベクトルの数を表します。
ベクトルのサイズが小さくても大きくても式テンプレートを使った Expr の方が速いことが分かります。Temp / Expr の比を取った次のグラフではよりそれが分かりやすいと思います。
エラーバーは誤差の伝搬の公式に従って標準誤差を伝搬したものです。
これで式テンプレートがベクトルの足し算において実際に速いことを確認できました。
ソースコード
本文中の TempVec.h と ExprVec.h を使って次のプログラムを作り、実験に使いました。
#include <iostream>
#include <iomanip>
#include <chrono>
#include <random>
#include "TempVec.h"
#include "ExprVec.h"
namespace Test {
template <typename VecType>
void test_add2(std::size_t size) {
std::cout << "Test of addition of 2 vectors" << "\n";
std::cout << " Vector Type: " << typeid(VecType).name() << "\n";
std::cout << " Vector Size: " << size << "\n";
std::cout << " Wait..." << std::endl;
VecType vec1(size), vec2(size);
std::mt19937 rand_eng;
std::uniform_real_distribution<double> rand_dist;
for (std::size_t i = 0; i < size; ++i) {
vec1[i] = rand_dist(rand_eng);
vec2[i] = rand_dist(rand_eng);
}
double sum = 0.0;
double sum2 = 0.0;
constexpr std::size_t rep = 100;
for (std::size_t i = 0; i < rep; ++i) {
auto time_begin = std::chrono::high_resolution_clock::now();
VecType vec3 = vec1 + vec2;
auto time_end = std::chrono::high_resolution_clock::now();
auto duration = time_end - time_begin;
double time = static_cast<double>(duration.count())
* static_cast<double>(decltype(duration)::period::num)
/ static_cast<double>(decltype(duration)::period::den);
sum += time;
sum2 += time * time;
}
double mean = sum / static_cast<double>(rep);
double err = std::sqrt((sum2 - sum * sum / static_cast<double>(rep))
/ static_cast<double>(rep * (rep - 1)));
std::cout << std::scientific << std::setprecision(3);
std::cout << " Time: " << mean << " +- " << err << " sec." << std::endl;
}
template <typename VecType>
void test_add3(std::size_t size) {
std::cout << "Test of addition of 3 vectors" << "\n";
std::cout << " Vector Type: " << typeid(VecType).name() << "\n";
std::cout << " Vector Size: " << size << "\n";
std::cout << " Wait..." << std::endl;
VecType vec1(size), vec2(size), vec3(size);
std::mt19937 rand_eng;
std::uniform_real_distribution<double> rand_dist;
for (std::size_t i = 0; i < size; ++i) {
vec1[i] = rand_dist(rand_eng);
vec2[i] = rand_dist(rand_eng);
vec3[i] = rand_dist(rand_eng);
}
double sum = 0.0;
double sum2 = 0.0;
constexpr std::size_t rep = 100;
for (std::size_t i = 0; i < rep; ++i) {
auto time_begin = std::chrono::high_resolution_clock::now();
VecType vec4 = vec1 + vec2 + vec3;
auto time_end = std::chrono::high_resolution_clock::now();
auto duration = time_end - time_begin;
double time = static_cast<double>(duration.count())
* static_cast<double>(decltype(duration)::period::num)
/ static_cast<double>(decltype(duration)::period::den);
sum += time;
sum2 += time * time;
}
double mean = sum / static_cast<double>(rep);
double err = std::sqrt((sum2 - sum * sum / static_cast<double>(rep))
/ static_cast<double>(rep * (rep - 1)));
std::cout << std::scientific << std::setprecision(3);
std::cout << " Time: " << mean << " +- " << err << " sec." << std::endl;
}
}
int main() {
std::cout << "Push enter to start this program" << std::endl;
std::cin.ignore(100, '\n');
constexpr std::size_t size = 10000000;
Test::test_add2<Temp::Vector>(size);
Test::test_add2<Expr::Vector>(size);
std::cout << std::endl;
Test::test_add3<Temp::Vector>(size);
Test::test_add3<Expr::Vector>(size);
std::cout << std::endl;
std::cout << "Push enter to end this program" << std::endl;
std::cin.ignore(100, '\n');
}