解法の理解に時間がかかったので分かりづらかったところを中心に解説を書きます。公式解説の補足のようなものです。公式解説がすぐ理解できた人にとって読む必要のある部分はありません。ゆるふわ解説です。
問題と公式解説
自分の解答(Python3)
1~16行目: 二項係数の前計算
18~22行目: 入力やdpテーブルの準備
24~27行目: dpテーブルの更新←ここの解説をします
29行目: 出力
解説
公式解説と同じ解法で解きます。
なぜ答えが奇数のときは0通りなのか
ビット毎に考えます。
総 $xor$ が $0$ になるには、任意の $1$ 以上の整数 $i$ について、「下から $i$ 番目のビットが立っている要素」の数が数列全体で偶数個になっていなければなりません。
答えが奇数だと下から $1$ 番目のビットが奇数個立っていることになるので、総 $xor$ が0になりません。
dpテーブルの更新
実装するとこうなりました。
for i in range(2, M+1, 2):
for j in range(0, min(N, i)+1, 2):
dp[i] += dp[(i-j)//2] * cmb(N, j)
dp[i] %= MOD
cmb(N, j)
この関数は $_N C_j$ を $998244353$ で割った余りを返します。
解説の漸化式中の $j$ は「最下位ビットの立っている要素の数」を2で割ったものですが、分かりづらいと思ったので、解答中の $j$ は「最下位ビットの立っている要素の数」そのままとしています。
$dp_i$ は「総和が $i$ 、総 $xor$ が $0$ になる非負整数列の数」です。
※以下、数列というときは総 $xor$ が $0$ かつ要素は全て非負整数とします。
これは言い換えると「最下位ビットが $0$ 個立っている時の総和が $i$ になる数列の数 $+$ 最下位ビットが $1$ 個立っているときの数列の数… $+$ 最下位ビットが $N$ 個立っている時の数列の数」になります。
この「最下位ビットがいくつ立っているか」が $j$ です。$j$ 毎に組み合わせの数を計算し足し合わせます。
「最下位ビットが $j$ 個立っているときの総和が $i$ になる数列の数」というのは、「最下位ビットが一つも立ってない総和が $i-j$ になる数列の数 $\times$ 最下位ビットが立つ $j$ 個の要素の選び方」と同じです。
この「最下位ビットが一つも立ってない総和が $i-j$ になる数列の数」を求めるのがなんとなくめんどくさそうですが、これは全体を $1$ ビット右にずらして考えると「総和が $(i-j) \div 2$ になる数列の数」と同じになります。
例
例えば、$N=5, i=6$ の場合を考えます。
$i=5$ まで計算した時点でdpテーブルは以下のようになります。
$i$ | 0 | 1 | 2 | 3 | 4 | 5 |
---|---|---|---|---|---|---|
$dp_i$ | 1 | 0 | 10 | 0 | 15 | 0 |
$j=0$ のとき、「最下位ビットを使わない総和が $6$ になる数列の数」を求めます。これは「総和が $3$ になる数列の数」と同じなので $0$ 通りです。
$j=2$ のとき、「最下位ビットを使わない総和が $4$ になる数列の数」を求めます。これは「総和が $2$ になる数列の数」と同じになので $dp_2$ の値を使います。$N$ 個の要素から最下位ビットが立つ要素を $2$ つ選ぶときの選び方が $10$ 通りなので、$10 \times 10 = 100$ 通りになります。
$j=4$ のときは「総和が $1$ になる数列の数」が $0$ なので $0$ 通りです。
$N$ 個より多く要素を選ぶことは出来ないので $i=6$ のときこれ以上の更新はありません。