LoginSignup
1

Eigen::Tensorのargmaxを実装してみた

Last updated at Posted at 2019-02-08

初めに

テンソルでもmax値とその場所が分かるargmaxを使いたいことが多いですが,~~なぜかargmaxが実装されていなかったので,~~Eigen::TensorのReduceを実装することにより実装しました.
オリジナルのReducerの実装例にもなっていますので,ご覧ください.

実装してみたら,なぜargmaxがないのか理解しました.
既にargmaxの実装がありました.argmaxメソッドをシンプルに使えばインデックスの値が取れますし,Reducerで指定すればインデックスと値の両方同時に取得もできます.testコードに使用例が見つかります.実装の方針は同じでした.

方針

基本的には,要素同士でどういう計算をするかの実装だけしてあげれば,簡単に実装できてしまう構造になっています.素晴らしいです.
ただ,今回は少し込み入った事情があります.

インデックスと値の組を保持したまま,値の方でmaxを取るようなmaxオペレータが必要そうです.残念ながら,プリミティブ型をテンソルの値としたmaxではインデックスの情報が消えてしまいます.かといって,インデックス情報が途中で追加されて出てくる,すなわち入力テンソルの型と出力テンソルの型が異なるようなオペレータ実装は許されていないようです.
そこで,Tensor<ValueType, Rank>ではなく,Tensor<Pair<int,ValueType>, Rank>のmaxを取るオペレータを実装することにしました.どうにも直では無理そうです.
「直では無理」が,「なぜargmaxがないのか」の回答でした.

移し替えのオーバヘッドがかかってしまうのが気になるところです.

実装

まずはTensorに入れる{インデックス,値}のペア構造体を定義.
比較やmaxが取れるようにオペレータを実装しておく.

template<typename T>
struct KeyVal {
	//中身
	int key;
	T val;
	//コンストラクタたち
	KeyVal() : key(-1), val(T(0)) {}
	KeyVal(int _a, T _b) : key(_a), val(_b) {}
	//オペレータたち
	KeyVal& operator=(const KeyVal &right) {
		this->key = right.key;
		this->val = right.val;
		return *this;
	}
	bool operator>(KeyVal &right) const {
		return this->val > right.val;
	}
	friend static KeyVal max(const KeyVal &ta, const KeyVal &tb) {
		return ta.val < tb.val ? tb : ta;
	}
	//表示用
	friend static std::basic_ostream<char, std::char_traits<char>>& operator<<(std::basic_ostream<char, std::char_traits<char>>& ta, const KeyVal &tb) {
		return ta
			<< std::right << std::noshowpos << std::setfill(' ') << std::setw(0) << "("
			<< std::setw(2) << std::setfill(' ') << tb.key
			<< std::right << std::noshowpos << std::setfill(' ') << std::setw(0) << ","
			<< std::setw(2) << std::setfill(' ') << tb.val
			<< std::right << std::noshowpos << std::setfill(' ') << std::setw(0) << ")"
			;
	}
};

Reducerを実装.
Eigen::MaxReducerをヘッダからコピって改造.

//MaxReducerのフォーク
//↑の構造体を入れるためだけのReducer実装
template <typename T>
struct ArgMaxReducer
{
	static const bool PacketAccess = Eigen::internal::packet_traits<KeyVal<T>>::HasMax;
	static const bool IsStateful = false;

	EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const KeyVal<T> t, KeyVal<T>* accum) const {
		if (t > *accum) { *accum = t; }
	}
	template <typename Packet>
	EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reducePacket(const Packet& p, Packet* accum) const {
		(*accum) = KeyVal::max(*accum, p); //ここ変更
	}
	EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE KeyVal<T> initialize() const {
		return KeyVal<T>(-1, Eigen::internal::MinMaxBottomValue<T, true, Eigen::NumTraits<T>::IsInteger>::bottom_value()); //ここ変更
	}
	template <typename Packet>
	EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet initializePacket() const {
		return Eigen::internal::pset1<Packet>(initialize());
	}
	EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE KeyVal<T> finalize(const KeyVal<T> accum) const {
		return accum;
	}
	template <typename Packet>
	EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet finalizePacket(const Packet& vaccum) const {
		return vaccum;
	}
	template <typename Packet>
	EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE KeyVal<T> finalizeBoth(const KeyVal<T> saccum, const Packet& vaccum) const {
		return KeyVal::max(saccum, predux_max(vaccum)); //ここ変更
	}
};

使ってみる.

int main() {
	//4x4x3の3階のテンソルを作成
	//データ順は後ろから順 デフォルトのColMajorだと前から順
	Eigen::Tensor<int, 3, Eigen::RowMajor> base({ 4,4,3 });

	//データ配置を知るために一次元配列から直コピーしてみる
	//ふつうはTensorMapでいいと思います.
	std::vector<int> baseb(4 * 4 * 3);
	for (int i = 0; i < baseb.size(); ++i) baseb[i] = baseb.size()-i-1;
	memcpy(base.data(), &baseb[0], sizeof(int) * 4 * 4 * 3);

	//表示してみる
	for (int r = 0; r < 4; ++r) {
		for (int c = 0; c < 4; ++c) {
			for (int i = 0; i < 3; ++i) {
				std::cout << base(r, c, i) << ",";
			}
			std::cout << " ";
		}
		std::cout << std::endl;
	}
	std::cout << std::endl;

	//実装済のmaximumでreduceしてみる.2番(0開始なので最終次元)の次元を縮退
	Eigen::Tensor<int, 2, Eigen::RowMajor> reduced = base.maximum(Eigen::array<Eigen::DenseIndex, 1>{2});

	//結果
	std::cout << reduced << std::endl;
	std::cout << std::endl;

	//argmax実行のために要素をインデックス付きにする
	Eigen::Tensor<KeyVal<int>, 3> paired({ 4,4,3 });
	for (int r = 0; r < 4; ++r) {
		for (int c = 0; c < 4; ++c) {
			for (int i = 0; i < 3; ++i) {
				paired(r, c, i) = KeyVal<int>((r*4+c)*3+i, base(r,c,i));
			}
		}
	}

	//実装したオペレータを使ってargmaxのReductionを実行
	Eigen::Tensor<KeyVal<int>, 2, Eigen::RowMajor> reduced_paired = paired.reduce(Eigen::array<Eigen::DenseIndex, 1>{2}, ArgMaxReducer<KeyVal<int>>());

	//結果
	for (int r = 0; r < 4; ++r) {
		for (int c = 0; c < 4; ++c) {
			std::cout << reduced_paired(r, c) << " ";
		}
		std::cout << std::endl;
	}
	std::cout << std::endl;

	return 0;
}

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1