#前書き
東京海上日動2020のE問題とABC172のE問題で出てきたので、上記の二つの問題を解説しつつ包除原理について自分なりにまとめていきます。
また、hamayanhamayanさんの記事に包除原理のまとめがあったので、詳しくはこちらを参照してください。
#包除原理とは
包除原理とは集合$A_i(i=1$~$n)$の和集合の元の個数が以下のように表される原理のことです。また、集合$X$の元の個数を以下では$|X|$と表します。
$$|\cup^{n}_{i=1}A_i|=\sum_{i}{|A_i|}-\sum_{i<j}{|A_i \cap A_j|}+…+(-1)^{n-1}|A_1 \cap … \cap A_n|$$
この式でのポイントは、k個の集合の積集合を考えると、kが奇数の時は+で偶数の時は-になることです。実装の際に重要になるので問題の解説で触れています。
また、上式の証明はここでは触れませんが、こちらの記事の$n=3$での図をみると理解が深まると思うので参考にしてください。
#包除原理への道筋
包除原理は具体的には以下の二つのパターンに当てはまる場合に使うと良いと考えられます。
集合$A_i(i=1$~$n)$がそれぞれある条件を満たすような集合とした時、
①それらの条件のうち少なくとも一つを満たす集合の要素数を求めたい場合
②それらの条件を一つも満たさない集合の要素数を求めたい場合(①の余集合を考える)
また、包除原理で積集合を考える際にはその積集合が成り立つ条件のみを考慮すればよくそれ以外の条件は満たしていても満たしてなくても良いことに注意が必要です。
#包除原理の実装
まずは簡単なABCの問題から解説を書きます。
$1 \leqq A_i,B_i \leqq M$のもとで「$A_i \neq B_i(1 \leqq i \leqq N)$…①」かつ「$A_i \neq A_j(1 \leqq i <j \leqq N)$かつ$B_i \neq B_j(1 \leqq i <j \leqq N)$…②」が成り立つような数列$A,B$の組を求める問題です。
②についてはその数列の要素が全て異なるような場合の数で、一つの数列については$_MP_N$となります。したがって、条件が②のみであれば$(_MP_N)^2$通りになります。しかし、①の条件があるので一部のパターンは許容されません。ここでは、数列Aと数列Bを両方動かすと考えにくいので片方の数列Aを固定して考えます。まずBを貪欲に動かして考えますが、なかなか難しいことが実験よりわかります。
このもとで何が難しいのかを言語化することができれば、包除原理に至ることができます。つまり、「$A_i = B_i$が成り立つ$i$が存在する可能性があり、貪欲法では排除するのが難しい」ということです。したがって、$A_i=B_i$を一つも満たさないような数列$A,B$の組み合わせを求めたいことと同値で、これは先ほどの包除原理の道筋の②に当てはまります。
実は、包除原理は思いつきさえすれば実装自体は難しくありません。道筋の②に従って、余事象である$A_i = B_i$が少なくとも一つ成り立つ数列$A,B$の組み合わせを考えます。この時、$A_i = B_i$が成り立つ$i$の選び方は$2^N$通りで、bit全探索と同様に立っているビット$i$で$A_i = B_i$が成り立つとし、**奇数の場合はその場合の数を+でカウントし偶数の場合はその場合の数を-**でカウントするとすれば、包除原理に則りつつ$O(N2^N)$での探索が可能です。しかし、$N$が最大で$5 \times 10^5$なのでこのままでは探索が間に合いません。
ここで、実験を行って数列の組み合わせの数は$A_i=B_i$が成り立つ$i$の個数$k$のみに依存し、それぞれの数が何かには依存しないことがわかったので、$A_i=B_i$が成り立つ$i$の個数$k$に注目してその場合の数を考えます。まず、$A_i=B_i$が成り立つ$k$個の$i$の位置の決め方は$_NC_k$通りで、残りの項は選んでいない$M-k$個の数のうちの$N-k$個を並べれば$_{M-k}P_{N-k}$通りとなります。したがって、数列$A$が固定され$A_i=B_i$が成り立つ$i$が$k$個である時、$_NC_i \times _{M-k}P_{N-k}$通りの数列$A,B$の組み合わせがあります。
よって、$k=1$~$N$について上記の計算をそれぞれ$O(1)$で行い、最後に数列$A$の組み合わせの数$_MP_N$をかけることで$O(N)$答えを求めることができます。
以上を実装して以下のようになります。また、上記では組み合わせ計算を$O(1)$で求められることを暗黙的に利用していますが、これはこの記事のmodintライブラリ中の組み合わせ計算で前計算を行っています。使用する際にはCOMinit()
を忘れずにしてください。
C++のコード
//インクルード(アルファベット順)
#include<algorithm>//sort,二分探索,など
#include<bitset>//固定長bit集合
#include<cmath>//pow,logなど
#include<complex>//複素数
#include<deque>//両端アクセスのキュー
#include<functional>//sortのgreater
#include<iomanip>//setprecision(浮動小数点の出力の誤差)
#include<iostream>//入出力
#include<iterator>//集合演算(積集合,和集合,差集合など)
#include<map>//map(辞書)
#include<numeric>//iota(整数列の生成),gcdとlcm(c++17)
#include<queue>//キュー
#include<set>//集合
#include<stack>//スタック
#include<string>//文字列
#include<unordered_map>//イテレータあるけど順序保持しないmap
#include<unordered_set>//イテレータあるけど順序保持しないset
#include<utility>//pair
#include<vector>//可変長配列
using namespace std;
typedef long long ll;
//マクロ
//forループ関係
//引数は、(ループ内変数,動く範囲)か(ループ内変数,始めの数,終わりの数)、のどちらか
//Dがついてないものはループ変数は1ずつインクリメントされ、Dがついてるものはループ変数は1ずつデクリメントされる
#define REP(i,n) for(ll i=0;i<(ll)(n);i++)
#define REPD(i,n) for(ll i=n-1;i>=0;i--)
#define FOR(i,a,b) for(ll i=a;i<=(ll)(b);i++)
#define FORD(i,a,b) for(ll i=a;i>=(ll)(b);i--)
//xにはvectorなどのコンテナ
#define ALL(x) x.begin(),x.end() //sortなどの引数を省略したい
#define SIZE(x) ll(x.size()) //sizeをsize_tからllに直しておく
//定数
#define INF 1000000000000 //10^12:極めて大きい値,∞
#define MOD 1000000007 //10^9+7:合同式の法
#define MAXR 600000 //10^5:配列の最大のrange(素数列挙などで使用)
//略記
#define PB push_back //vectorヘの挿入
#define MP make_pair //pairのコンストラクタ
#define F first //pairの一つ目の要素
#define S second //pairの二つ目の要素
#define Umap unordered_map
#define Uset unordered_set
template<ll mod> class modint{
public:
ll val=0;
//コンストラクタ
modint(ll x=0){while(x<0)x+=mod;val=x%mod;}
//コピーコンストラクタ
modint(const modint &r){val=r.val;}
//算術演算子
modint operator -(){return modint(-val);} //単項
modint operator +(const modint &r){return modint(*this)+=r;}
modint operator -(const modint &r){return modint(*this)-=r;}
modint operator *(const modint &r){return modint(*this)*=r;}
modint operator /(const modint &r){return modint(*this)/=r;}
//代入演算子
modint &operator +=(const modint &r){
val+=r.val;
if(val>=mod)val-=mod;
return *this;
}
modint &operator -=(const modint &r){
if(val<r.val)val+=mod;
val-=r.val;
return *this;
}
modint &operator *=(const modint &r){
val=val*r.val%mod;
return *this;
}
modint &operator /=(const modint &r){
ll a=r.val,b=mod,u=1,v=0;
while(b){
ll t=a/b;
a-=t*b;swap(a,b);
u-=t*v;swap(u,v);
}
val=val*u%mod;
if(val<0)val+=mod;
return *this;
}
//等価比較演算子
bool operator ==(const modint& r){return this->val==r.val;}
bool operator <(const modint& r){return this->val<r.val;}
bool operator !=(const modint& r){return this->val!=r.val;}
};
using mint = modint<MOD>;
//入出力ストリーム
istream &operator >>(istream &is,mint& x){//xにconst付けない
ll t;is >> t;
x=t;
return (is);
}
ostream &operator <<(ostream &os,const mint& x){
return os<<x.val;
}
//累乗
mint modpow(const mint &a,ll n){
if(n==0)return 1;
mint t=modpow(a,n/2);
t=t*t;
if(n&1)t=t*a;
return t;
}
//二項係数の計算
mint fac[MAXR+1],finv[MAXR+1],inv[MAXR+1],perm[MAXR+1];
//テーブルの作成
void COMinit() {
fac[0]=fac[1]=1;
finv[0]=finv[1]=1;
inv[1]=1;
FOR(i,2,MAXR){
fac[i]=fac[i-1]*mint(i);
inv[i]=-inv[MOD%i]*mint(MOD/i);
finv[i]=finv[i-1]*inv[i];
}
perm[0]=1;
REP(i,MAXR){
perm[i+1]=mint(i+1)*perm[i];
}
}
//演算部分
mint COM(ll n,ll k){
if(n<k)return 0;
if(n<0 || k<0)return 0;
return fac[n]*finv[k]*finv[n-k];
}
mint PERM(ll n,ll k){
return COM(n,k)*perm[k];
}
signed main(){
COMinit();
ll n,m;cin>>n>>m;
//aが何通り
mint ans=PERM(m,n);
mint ans_sub=0;
FOR(i,1,n){
if(i%2){
ans_sub+=(COM(n,i)*PERM(m-i,n-i));
}else{
ans_sub-=(COM(n,i)*PERM(m-i,n-i));
}
}
//余事象
cout << ans*(PERM(m,n)-ans_sub) << endl;
}
オーバフローの原因が分からずに二日間もかかって地獄でしたが、teratailで質問に回答していただき助かりました…。
まず、$S$でも$T$でも1になるようなビットは1しかありえないことに気づいたので、それぞれのビットについて論理積が$S$で論理和が$T$であることの言い換えをします。また、$S,T$の$i$ビット目をそれぞれ$S_i,T_i$とし、$(S_i,T_i)=(0,0),(0,1)(1,0),(1,1)$の4通りの場合分けをします。
[1]$(S_i,T_i)=(0,0)$の時、$i$ビット目が0の数のみが候補なので$i$ビット目が1の数は$A$から除く。
[2]$(S_i,T_i)=(0,1)$の時、$i$ビット目が0の数と1の数のどちらも$A$の候補としてありうるので$A$に入れたままにする。
[3]$(S_i,T_i)=(1,0)$の時は$S,T$が矛盾しているので、0を出力してプログラムを終了する。
[4]$(S_i,T_i)=(1,1)$の時は$i$ビット目が1の数のみが候補なので$i$ビット目が0の数は$A$から除く。
以上四つを行うと、題意を満たすように選べる数のみが$A$に残り、[2]の$(S_i,T_i)=(0,1)$を満たす全てのビットについてそれぞれ0と1になる数が少なくとも一つずつある場合(✳︎)を考えれば良いことになります。
ここで、0と1が少なくとも一つずつあるという条件は否定を考えて0のみまたは1のみあるとした方が考えやすく、(✳︎)は「$(S_i,T_i)=(0,1)$を満たす全てのビットについて0のみまたは1のみとなるものが存在しない場合」と言えます。したがって、これは包除原理の道筋②なので、余事象である「$(S_i,T_i)=(0,1)$を満たす全てのビットについて0のみまたは1のみとなるものが少なくとも一つ存在する場合」を考えれば良いです。
ここで、包除原理で条件が成り立つビットを見る(積集合を考える)時、そのビットは0のみまたは1のみになるという条件が成り立ちます。ここで、0になる数の集合と1になる数の集合は排反かつ和集合が全体集合であることに注目すると、包除原理で条件が成り立つビットをみる時に集合を分割するというイメージで行うことができます。例えば、初期状態では$A_1,A_2,…,A_n$を内包する集合を考えますが、あるビットで0になる$A$の部分集合と1になる$A$の部分集合を保存しておけば、どちらに含まれるかで二つに分割することができます(分割された集合はいずれも0のみまたは1のみという条件を満たします。)。これを条件が成り立つそれぞれのビットに対してそれぞれ行って最終的に分割されて残った集合について数の1個~k個の選び方を考えれば解を求めることができます。
以上より、包除原理の積集合の全探索で高々$2^{18}$回の探索かつそれぞれのビット(高々18ビット)で分割操作は高々$N$回なので、合計で高々$2^{18} \times 18 \times N$回の計算を行えばよく十分に高速なプログラムを書くことができます。また、集合にどの要素が含まれるかをビットで記録することで積集合を&で代替し高速化しています。
以上で方針の概要は説明したので、細かいところを説明していきます。
下記のコードにおいて(1)~(7)で表示している部分に対応した形での説明を行うので注意してください。
(1)
初めの場合分けの部分を実装しますが、nとkがそれぞれ変わりうるので注意してください。
(2)
[2]のパターンで0のみまたは1のみしかないビットが存在する場合は題意を満たすような数字の選び方が存在しないので0を出力して終了します。また、このようなパターンを除かないと、下記のようなパターンで0を出力するべきにも関わらず1を出力してしまいます。
2 1 0 3
2 4
また、[2]のパターンのビットが0である数の集合と1である数の集合をそれぞれ保存するために、どの要素が集合に含まれるかをビットで表す整数をpairで保存します。この際の左シフト演算で1LL<<j
とLLを付けないとint型となってオーバフローする可能性があるので注意が必要です。
(3)
最後に分割されて残った集合それぞれに対して1個~k個選ぶ場合の組み合わせ計算を$O(1)$で求める必要があります。つまり、一つの集合の元の数を$x$として、$_xC_1+_xC_2+…+_xC_{min(k,x)}$を$O(1)$で計算する必要がありますが、それぞれの組み合わせ計算をパスカルの三角形で求めた後に累積和を計算することで前計算を行うことができます。(0-indexedで$pcalc[i][j]:=_iC_1+_iC_2+…+_iC_j$となります。また、パスカルの三角形による組み合わせ計算はこちらを参考にすると良いです。)
(4)
最後にビット全探索の要領で包除原理の計算を行います。
(5)
包除原理の計算における集合の分割では、(整数で表される)分割された集合族はdequeに保存し、積集合をとるビットではdequeに保存されているそれぞれの集合について先頭から&演算を行って集合が空集合でない場合は末尾に再度挿入することを繰り返せば良いです(dequeを使ったのは挿入と削除が$O(1)$でできるからです。)。
(6)
(5)において全てのビットのチェックを終えたら、積集合を取った元での集合族がそれぞれ整数の形で得られます。ここで、それぞれの集合の中からの1個~k個の選び方を考えますが、(3)での前計算によりそれぞれの集合の元の個数がわかれば良いです。集合を整数で表現しておりそれぞれのビットがその要素が含まれるかに対応するので、1になるビットがいくつあるかを考えれば良いです。これは、__builtin_popcount関数で得ることができます。また、この際に**__builtin_popcount関数の引数はunsigned int型でオーバフローが起きる可能性がある**ので、__builtin_popcountll関数を使う必要があります。よって、$s\_sub$を__builtin_popcountllで求めた値とすれば、$pcalc[s\_sub][min(k,s\_sub)]$として$O(1)$でその積集合の組み合わせの数を求められますが、使った条件の個数の偶奇で+か-かが決まるので注意が必要です。
✳︎c++20であればでpopcount関数が実装されているようでオーバフローを気にせずに使えますが、現在のAtCoderの環境では使えないようです。
(7)
また、全集合の場合の数は$pcalc[n][k]$であることから、(6)までで求めた場合の数を$pcalc[n][k]$から引いたものを出力すれば解を求めることができます。
C++のコード
//bit演算の右側はllでないと型変換されない
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
//マクロ
//forループ関係
//引数は、(ループ内変数,動く範囲)か(ループ内変数,始めの数,終わりの数)、のどちらか
#define REP(i,n) for(ll i=0;i<(ll)(n);i++)
#define FOR(i,a,b) for(ll i=a;i<=(ll)(b);i++)
//xにはvectorなどのコンテナ
#define ALL(x) (x).begin(),(x).end() //sortなどの引数を省略したい
#define SIZE(x) ((ll)(x).size()) //sizeをsize_tからllに直しておく
//定数
#define INF 1000000000000 //10^12:極めて大きい値,∞
#define MOD 1000000007 //10^9+7:合同式の法
#define MAXR 100 //10^5:配列の最大のrange(素数列挙などで使用)
//略記
#define PB push_back //vectorヘの挿入
#define MP make_pair //pairのコンストラクタ
#define F first //pairの一つ目の要素
#define S second //pairの二つ目の要素
signed main(){
//入力の高速化用のコード
ios::sync_with_stdio(false);
cin.tie(nullptr);
ll n,k,s,t;cin >> n >> k >> s >> t;
vector<ll> b(n);REP(i,n)cin >> b[i];
vector<ll> a;
//(1)
REP(i,n){
bool f=true;
REP(j,18){
bool s_,t_,b_;s_=(s>>j)&1;t_=(t>>j)&1;b_=(b[i]>>j)&1;
if(s_ and t_){
if(!(b_)){
f=false;break;
}
}else if(!(s_) and !(t_)){
if(b_){
f=false;break;
}
}else if(s_ and !(t_)){
cout<<0<<endl;return 0;
}
}
if(f)a.PB(b[i]);
}
n=SIZE(a);
k=min(n,k);
//(2)
vector<pair<ll,ll>> bits;
REP(i,18){
ll f1=0;ll f2=0;
REP(j,n){
if((a[j]>>i)&1){
f1+=(1LL<<j);
}else{
f2+=(1LL<<j);
}
}
if(f1 and f2){
bits.PB(MP(f1,f2));
}else if(!((s>>i)&1) and (t>>i)&1){
cout<<0<<endl;return 0;
}
}
ll bits_len=SIZE(bits);
//(3)
vector<vector<ll>> pcalc(n+1,vector<ll>(k+1,0));
pcalc[0][0]=1;
FOR(i,1,n){
pcalc[i][0]=1;
FOR(j,1,k){
pcalc[i][j]=pcalc[i-1][j-1]+pcalc[i-1][j];
}
}
REP(i,n+1){
pcalc[i][0]=0;
REP(j,k){
pcalc[i][j+1]+=pcalc[i][j];
}
}
//(4)
ll ans_sub=0;
deque<ll> all_sets;
FOR(i,1,(1LL<<bits_len)-1){
//(5)
all_sets.PB((1LL<<n)-1);
ll odd_even=0;
REP(j,bits_len){
if((i>>j)&1){
odd_even+=1;
ll s=SIZE(all_sets);
REP(_,s){
ll ins=*(all_sets.begin());
all_sets.pop_front();
ll f1=ins&(bits[j].F);ll f2=ins&(bits[j].S);
if(f1){all_sets.PB(f1);}
if(f2){all_sets.PB(f2);}
}
}
}
//(6)
ll s=SIZE(all_sets);
REP(_,s){
ll now=*(all_sets.begin());
all_sets.pop_front();
ll s_sub=__builtin_popcountll(now);
if(odd_even%2){
ans_sub+=pcalc[s_sub][min(k,s_sub)];
}else{
ans_sub-=pcalc[s_sub][min(k,s_sub)];
}
}
}
//(7)
cout<< pcalc[n][k]-ans_sub <<endl;
}