ABC356_Masked Popcount
なるべく、数式とか使わずに直感でわかるように説明していきたいです。
問題文
リンク:https://atcoder.jp/contests/abc356/tasks/abc356_d
整数 N,M が与えられるので、
∑{K=0〜N} popcount(k&M)
を 998244353 で割った余りを求めてください。
※& はビット単位AND演算を表します。
制約
0 ≦ N < 60
0 ≦ M < 2^60
popcount(k)
kを2進数にしたとき、1が立っている桁は何個あるか?ということです。
必要な知識
・進数計算(10→2)
・論理演算(AND)
・popcount公式
本解説
問題文の要約
K&Mにより、Mが0の桁は常に強制的に0にされてしまいます。
0にされるということはその桁の1が数えられることはありません。
毎回Mの&演算が行われるのでΣにおける足し算すべてで同様です。
つまり、
①:まずはMは一切考えずに、Nまでの数において各桁ごとに 「1」 が出現する回数を求める。
②:Mで1が立っている桁に対応する箇所のみを足し合わせる。これが答え。
で求めることができます。
検討:ブルートフォースでいけそうか
与えられる数は2^60なので無理でした。
最悪、足し算を2^60回することになりますからね。
popcountの公式を使う
突然ですが、1111(2)までを0000から順番に書き並べてみます。
0000
0001
0010
0011
0100
0101
0110
0111
1000
1001
1010
1011
1100
1101
1110
1111
各桁ごとに登場する 1 の回数を数えるとすべての桁で同じ 8個 です。
同様に、0から11111(2)までに各桁ごとに登場する1の回数はすべての桁で同じ数であり、その数は16個です。
0から2^nで最上位桁を除く各桁ごとに1が登場する回数は2^(n-1) ・・・①
例えば0から2^100の数で最上位桁未満の各桁で1が登場する回数は2^99となります。
これを用いることで一気に計算量を減らして各桁ごとの「1」の出現回数を求めることができます。
Nを2進数にして最上位ビットから右に向かってそれぞれのビットに以下の操作を繰り返します。
・そのビットが1の場合→
⭐︎そのビットに1が立つまでの下位ビットで1が現れる回数を①を用いて計算し、下位ビットそれぞれの「1」出現回数保持配列に足す。
⭐︎そのビットに1が出現する回数は右のビットを10進数変換して+1した数に等しいので、それを計算して現在地点のビットの「1」出現回数保持変数に足す。
・そのビットが0の場合→何もしない。
言葉で説明すると難しいので具体例を考えてから一般化していきます。
具体例で考える
N = 10 M = 3 とします。
Nb = 1010 、Mb = 0011 です。
各ビットごとに登場する1の回数を保持する変数を一番左を1として、A1,A2,A3,A4としましょう。
一番右のビットは1です。よって0〜111を①の公式で求めてA2〜A4に足します。
また、A1には、010を10進数に変換して+1したものを足します。
2番目のビットは0ですので何もしません。
3番目のビットは1です。よって0〜01を①の公式で求めてA4に足します。
4番目、つまり最後のビットは0です。よって何もしません。
これで各桁ごとの「1」の出現回数が求まりました。
Mは0011なのでこのうちA3,A4のビットだけが総和の対象となります。
よって、A3,A4を足して終わりです。
実際の問題ではmodで答えを出すので各操作ごとにmodで数を小さくするなどの工夫が必要です。
間に合う?計算量は?
Nは2^60なのでビット数は最大60です。
①の冪乗計算は繰り返し二乗法を使うことでlogに落とせます。(2^60程度ならば繰り返し2乗法使わなくてもいけます)
よってNの桁をdとして Σlog(d) となります。
最大ケースのd=60でも大したことありません。
コード
上記を元にPythonで書いたコードになります。
少し冗長ですが可読性を意識して書いてます。
n,m = map(int, input().split())
n_b = format(n,'b') #Nの二進数
n_l = len(n_b)
m_b = format(m,'b') #Mの二進数
m_l = len(m_b)
# 各桁ごとに1が出現する回数を計算
pop_arr = [0]*n_l #各桁ごとの「1」の出現回数保持配列
for i in range(n_l):
tmp_bit = n_b[i]
if tmp_bit == '0':
continue
elif tmp_bit == '1':
digit = n_l - i - 2
if digit < 0:
pop_arr[i]+=1
continue
temp = 2**digit
temp = temp % 998244353
pop_arr[i+1:] = [x + temp for x in pop_arr[i+1:]]
#現在のbitより右の部分を足す
r_bit = n_b[i+1:]
r_int = int(r_bit,2)
pop_arr[i] += r_int + 1 #全て0のパターンも個数に含むので+1
# mの1が立っている桁だけの総和を取る
total = 0
for j in range(m_l):
if j > n_l - 1: #Nの桁数を超えたら終了(どうせ0)
break
maskbit = m_b[m_l - j - 1]
if maskbit == '0':
continue
elif maskbit == '1':
total += pop_arr[n_l - j - 1]
total %= 998244353
#答えを出力
print(total)