LoginSignup
10
6

More than 1 year has passed since last update.

全人類Wavelet Matrixを書こう!

Last updated at Posted at 2023-04-11

はじめに

Wavelet Matrixって色々できて、持っていたらかっこいいですよね?
この記事は厳密なことを気にせずにWavelet Matrix(というか完備辞書)を実装するための記事です。実装は簡潔な代わりに、簡潔性や最悪計算量などが悪化してますが、あまり気にしない人が対象です。また、この記事はWavelet Matrixの機能などを知っている人(つまり実装したいけど、実装がめんどいっていう人)が対象の記事です。

//4/13 追記
この記事の実装には未定義動作が含まれている可能性が高いようなので、コードをコピペして使うことは推奨しません。現在修正中です。

//4/19 追記
記事のコードを大幅に修正しました。

tl; dr

ビット演算と適当1に組み合わせることで、簡単にコンパクト2な完備辞書が実装できます。
完備辞書を用意できたらMitI_7様の素晴らしい記事を参考にして欲しい機能を実装しましょう。

完備辞書

完備辞書とは次の操作ができるデータ構造を指します。(以下bit列の長さをNとします)

  • access(i) : i番目のbitを返す
  • rank(i) : i番目までの立っているbitを返す
  • select(i) : i番目に立っているbitのindexを返す(同様にi番目に立っていないbitのindexを返す関数も用意したほうが良い)

通常の完備辞書は大ブロックと小ブロックを分ける方法を用いて簡潔にしていますが、実装が難しい上にselectをrankを用いた二分探索によってO(logN)で実装することがほとんどですが、ここでは空間を2n + O(1)bitに悪化させる代わりに、簡単にrankを最悪O(1)、selectを平均O(1)、最悪O(logN)で実装します。

先にC++のコードを提示します。

bit_dict.cpp
//空間2n + O(1), rank:O(1), select:平均O(1), 最悪O(log(N))
struct bit_dict{
private:
	using ull = unsigned long long;
	int sz, size_of_memo_0, size_of_memo_1;

	ull* bits;
	//acc[n]->[0, 64n)の立っているbitの数
	int *acc;
	//memo_1[n]->[0, memo_1[n])の立っているbitが64n個であるような最小のn
	int *memo_0, *memo_1;

	inline pair<int, int> index(const int i)const{
		return make_pair(i >> 6, i & 0b111111);
	}

	inline bool access(const int k)const{
		auto [i, j] = index(k);
		return (bits[i] >> j) & 1;
	}

	void build(){
		acc[0] = 0;
		for (int i = 0; i < (sz >> 6); i++){
			acc[i + 1] = acc[i] + __builtin_popcountll(bits[i]);
			size_of_memo_0 += 64 - __builtin_popcountll(bits[i]);
			size_of_memo_1 += __builtin_popcountll(bits[i]);
		}

		size_of_memo_0 >>= 6;
		size_of_memo_1 >>= 6;
		size_of_memo_0 += 2;
		size_of_memo_1 += 2;

		unsigned count = 0, count0 = 0;
		memo_0 = (int*)malloc(sizeof(int) * size_of_memo_0);
		memo_1 = (int*)malloc(sizeof(int) * size_of_memo_1);
		for (int i = 0; i < sz; i++){
			if (access(i) && (count++ & 0b111111) == 0) memo_1[(count - 1) >> 6] = i;
			if (!access(i) && (count0++ & 0b111111) == 0) memo_0[(count0 - 1) >> 6] = i;
		}

		for (int i = (count0 >> 6) + 1; i < size_of_memo_0; i++) memo_0[i] = sz + 1;
		for (int i = (count >> 6) + 1; i < size_of_memo_1; i++) memo_1[i] = sz + 1;
	}

public:
	bit_dict() : sz(0), size_of_memo_0(0), size_of_memo_1(0), bits(nullptr), acc(nullptr), memo_0(nullptr), memo_1(nullptr) {}

	bit_dict(const bit_dict &b) : sz(b.size()), size_of_memo_0(b.size_of_memo_0), size_of_memo_1(b.size_of_memo_1), bits(nullptr), acc(nullptr), memo_0(nullptr), memo_1(nullptr){
		bits = (ull*)malloc(sizeof(ull) * ((sz >> 6) + 1));
		acc = (int*)malloc(sizeof(int) * ((sz >> 6) + 1));
		memo_0 = (int*)malloc(sizeof(int) * (size_of_memo_0));
		memo_1 = (int*)malloc(sizeof(int) * (size_of_memo_1));
		memcpy(bits, b.bits, sizeof(ull) * ((sz >> 6) + 1));
		memcpy(acc, b.acc, sizeof(int) * ((sz >> 6) + 1));
		memcpy(memo_0, b.memo_0, sizeof(int) * (size_of_memo_0));
		memcpy(memo_1, b.memo_1, sizeof(int) * (size_of_memo_1));
	}

	bit_dict(const vector<bool> &b) : sz(b.size()), size_of_memo_0(0), size_of_memo_1(0){
		bits = (ull*)malloc(sizeof(ull) * ((sz >> 6) + 1));
		acc = (int*)malloc(sizeof(int) * ((sz >> 6) + 1));
		memset(bits, 0, sizeof(ull) * ((sz >> 6) + 1));
		memset(acc, 0, sizeof(int) * ((sz >> 6) + 1));

		for (int i = 0; i < sz; i++){
			if(b[i]) bits[index(i).first] += 1ul << index(i).second;
		}

		build();
	}

	bit_dict(const int s) : sz(s), size_of_memo_0(0), size_of_memo_1(0), memo_0(nullptr), memo_1(nullptr){
		bits = (ull*)malloc(sizeof(ull) * ((sz >> 6) + 1));
		acc = (int*)malloc(sizeof(int) * ((sz >> 6) + 1));
		memset(bits, 0, sizeof(ull) * ((sz >> 6) + 1));
		memset(acc, 0, sizeof(int) * ((sz >> 6) + 1));
	}

	//初期化、同じindexについて複数回呼んだ時は未定義
	void init(const int i, const bool b){
		if (b) bits[index(i).first] += 1ul << index(i).second;

		if (i == sz - 1){
			build();
		}
	}

	//[0, k)で立っているbitの個数
	inline int rank(int k)const{
		auto [i, j] = index(k);
		return acc[i] + __builtin_popcountll(bits[i] & ((1ul << (j)) - 1));
	}

	//[l, r)で立っているbitの個数
	inline int rank(int l, int r)const{
		return rank(r) - rank(l);
	}

	//k番目に立っているbitの位置(0-indexed)、存在しない場合は-1を返す
	inline int select(const int k)const{
		int i = index(k + 1).first;

		//rがok, lがng
		int l = memo_1[i] - 1, r = memo_1[i + 1];
		while (r - l > 1){
			int mid = (r + l) >> 1;
			if (rank(mid) >= k + 1) r = mid;
			else l = mid;
		}
		return (r == sz + 1 ? -1 : r - 1);
	}

	//k番目に立っていないbitの位置(0-indexed)、存在しない場合は-1を返す
	inline int select_0(const int k)const{
		int i = index(k + 1).first;

		//rがok, lがng
		int l = memo_0[i] - 1, r = memo_0[i + 1];
		while (r - l > 1){
			int mid = (r + l) >> 1;
			if ((mid - rank(mid)) >= k + 1) r = mid;
			else l = mid;
		}
		return (r == sz + 1 ? -1 : r - 1);
	}

	//[l, size)でk番目に立っているbitの位置(0-indexed)
	inline int select(const int k, const int l)const{
		return select(k + rank(l));
	}

	inline bool operator[](const int k)const{
		return access(k);
	}

	int size()const{
		return sz;
	}
};

まず、データの持ち方は64bit整数の生配列を使っています。__builtin_popcountllが用意されているので、64bit整数の立っているbitを数えるのはO(1)でできるのがミソです。

access

bit演算をするだけです。便利なので、operator[]も定義しておきましょう

access.cpp
inline bool access(const int k)const{
	auto [i, j] = index(k);
	return (bits[i] >> j) & 1;
}

inline bool operator[](const int k)const{
	return access(k);
}

rank

64bit毎に立っているbitの数を保存することで、O(1)でrankが計算できます。

イメージとしては、64の倍数毎の答えはわかるので、あとは端数をマスク処理とpopcount命令でO(1)で計算し、それによってrankをO(1)で計算できます。

rank.cpp
    inline int rank(int k)const{
		auto [i, j] = index(k);
		return acc[i] + __builtin_popcountll(bits[i] & ((1ul << (j)) - 1));
	}

これの補助データ構造で消費する空間は、32bit整数 * N / 64 = N / 2bitです。

select

単にrankを用いて二分探索することではO(logN)になってしまうので、少し工夫してO(1)にします。
nを整数として、64n番目の立っているbitを記録します。

ここで、k番目を知りたいとき、(kよりも小さい最大の64の倍数)番目のbitが立っている場所をl、(kよりも大きい最小の64の倍数)番目のbitが立っている場所をrとして、[l, r)で二分探索をすることで、データが一様ランダムであるという仮定の元で、平均区間長が128になるはずなので3、約lg(128) = 7回の探索でselectを計算でき、平均O(1)を達成できます。

データが偏っている場合でも、区間は高々Nなので、O(logN)で計算できます。

select.cpp
    //k番目に立っているbitの位置(0-indexed)、存在しない場合は-1を返す
	inline int select(const int k)const{
		int i = index(k + 1).first;

		int l = memo_1[i] - 1, r = memo_1[i + 1];
		while (r - l > 1){
			int mid = (r + l) >> 1;
			if (rank(mid) >= k + 1) r = mid;
			else l = mid;
		}
		return (r == sz + 1 ? -1 : r - 1);
	}

立っていないbitについてのselect_0も同様に実装できます。このとき、size_of_memo_1 + size_of_memo_0 ~ N / 644より、補助データ構造で消費する空間はおおむね、32bit整数 * N / 64 = N / 2bitです。空間をもっと消費することで、定数倍高速化が可能です。

空間計算量

メインのデータを保存するのに、N + O(1)bit、補助データ構造にN + O(1)bitを消費しているので、2N + O(1)bitです。

Wavelet Matrix

ここまでできたら、あとはWavelet Matrixを実装するだけで、実装はあまり難しくはありません。5
以下の素晴らしい記事を参考にして、丁寧に実装しましょう。(丸投げ)

おわりに

ここまでお読みいただきありがとうございました。結局WaveletMatrixの記事というよりも完備辞書についての記事になりました。(タイトル詐欺)
色変記事以外の初めての記事なので、拙いところも多いと思います。誤り等ありましたら、ぜひご一報お願いします。

最後にvarify用リンクおよび私の提出を貼っておきます
varify用問題
508ms

  1. 特に深く考える必要がない、ある値の意です

  2. 競プロ文脈においてです

  3. 厳密にはよくわかっていません

  4. 番兵を置いているので、厳密に一致はしません。

  5. HashMapを実装する場合は結構大変です

10
6
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
10
6