ABC405 E - Fruit Lineup
https://atcoder.jp/contests/abc405/tasks/abc405_e
コンテスト中に考えて惜しいところまで行けたものの、解ききれませんでした。結局終了後に正解できました。悔しかったので記事にします。
コンテスト中の自分の動き
D問題に時間がかかり、終わった時には45分が経過していました。この時点で予測パフォーマンスはかなり悪かったです。D問題を正解できたものの、無駄にややこしく考えてしまいました。遠回りになっただけでなくコードが複雑になり書くのも遅くなりました……。それでも50分以上残っていたのでE問題を一生懸命考えました。
最初は出力例1を見ながら「要するに (AとB)(CとD)に分かれるんだから $ \frac{(A+B)!}{A!B!} \times \frac{(C+D)!}{C!D!} $ で、あとは ACBD のパターンの1通りを足せばいいのでは?」などと明後日の解法を考えました。当然ながら入力例2が合いません。悩みました。次に簡単な例を作って法則を見出せないかと思い、愚直に作って並べてみました。急がば回れということでわざわざ以下のコードを書きました。
import itertools
X = ["A", "A", "B", "B", "C", "D", "D"]
ans = set()
for patttens in itertools.permutations(X):
Y = dict()
for i, p in enumerate(patttens):
if p not in Y:
Y[p] = {i}
else:
Y[p].add(i)
flag = True
for a in Y["A"]:
for c in Y["C"]:
if a > c:
flag = False
for a in Y["A"]:
for d in Y["D"]:
if a > d:
flag = False
for b in Y["B"]:
for d in Y["D"]:
if b > d:
flag = False
if flag:
ans.add(patttens)
ans = list(ans)
ans.sort()
for an in ans:
print(an)
print(len(ans))
しかしお察しの通り何もアイデアは出てきませんでした。A,B,C,D が頭の中で入り乱れて頭が混乱してきました。
ここで「どうやら楽に計算する方法はなさそうだ」と思い直し、とりあえず条件の通りに素直にABCDを並べていくことにしました。で、考えていったら入力例2について手計算したものと結果が一致しました。法則さえわかればあとは一般化して数式にして実装するだけです。
しかしこの時点で残り時間が10分を切っていました。冷静さを欠いていたこともあり結局そのまま解けずに終わりました。終わった後に20分ぐらい考えてようやく正解できました。いっそDを飛ばしてEをやっていればあるいは本番中に間に合っていたかも?というところです。あまり現実的な話ではないですね。
考察
さて、では本番中終わりの方に考えた流れを書いていきます。まず並べられた条件を上から見ると A < C です(果物名で書くとわかりづらいのでアルファベットで表現します)。入力例2の場合(Aが1個、Cが4個)で表現すると以下のようになります。全ての A は全ての C より左にあるので、AとCに関しては必ずこの並び方になるはずです。
次に、ここの隙間にB, Dを入れていくことを考えます。それぞれの隙間には複数の果物を入れてもいいし、全く入れないところがあっても構いません。
ここで条件 B < D を見ます。これを満たすような入れ方は以下のようになります。Bを左から敷き詰め、Dを右から敷き詰めます。境目の1箇所のみB,Dが同居しても構いません。この境目における並べ方は当然 BBBB......BDDDD.....D の1通りです。
最後に A < D の条件がありますから、Dを配置できる範囲は限られてくることがわかります。Aよりも左にDが来てはいけません。ということで、Dをどの範囲まで配置するかによって場合分けをし、それぞれの答えを足し合わせることを考えます。
計算
それでは実際に計算してみましょう。引き続き入力例2の場合を元に考えていきます。
まずDが右から5箇所までに配分される場合を考えます。ただし、5箇所の隙間の中の左端の1箇所が空だと次以降のパターン(右から4箇所まで、3箇所まで、……)と区別ができなくなるので、必ず左端には1個は入っているものとします。なので、まず左端にDを1個入れ、残りの7個を5箇所の隙間に分配して入れていくことを考えます。
7個の果物を5個の箱に分配する場合の数の求め方ということになりますが、これはよく知られているように(※)「5個の箱」を「4個の仕切り」とみなすことで、7個の果物と4個の仕切りとを並べ替える問題になります。したがって $ {}_{11} C_4 = 330 $ 通りです。
同様に、Bは2個の果物を2箇所に分配する問題になりますから仕切り1個、果物2個で $ {}_{3} C_2 = 3$ 通りです。
つまり D を右から5箇所までに配置するときの場合の数はこれらを掛けた 990 通りになります。
(※ ここではさらっと書いていますが、本番中は頭が混乱しておりこれの求め方がなかなか出てきませんでした。自分の頭の悪さに辟易しますね。)
同様にDが右から4箇所に配分される場合を考えます。右から4つめにDを1個入れ、残り7個を4箇所の隙間に入れていきます。果物7個、仕切り3個なので $ {}_{10} C_3 = 120 $ 通りです。
Bの方は2個を3箇所に分配するので $ {}_{4} C_2 = 6 $、掛けて 720通りです。
Dが右から3箇所の場合は $ {}_{9} C_2 = 36 $ 通り。
Bは $ {}_{5} C_2 = 10 $ 通り。掛けて 360 通りです。
Dが右から2箇所の場合は $ {}_{8} C_1 = 8 $ 通り。
B は $ {}_{6} C_2 = 15 $ 通り。掛けて120 通りです。
Dが右から1箇所の場合は $ {}_{7} C_0 = 1 $ 通り。
B は $ {}_{7} C_2 = 21 $ 通り。掛けて 21 通りです。
これらを全て足し合わせると $ 990 + 720 + 360 + 120 + 21 = 2211 $ 通りとなり出力例2に一致します。
nCr を高速に求めるために
$ {}_{n} C_r $ を 998244353 で割った余りを高速に求める方法ですが、これはWeb上のいろんなところに解説記事が書かれていると思いますので説明は省略します(上手に説明できないので逃げます、すみません)。
簡単にいうと
$$ {}_{n} C_r = \frac{n!}{(n-r)! \times r!} $$
ですから、大きな数字の階乗の値(を 998244353 で割った余り)が必要になります。ここで、大きな数字の階乗の値を求めるには時間がかかります。そこで前計算しておくことにします。
この式の右辺を計算するにあたって「(n-r)! で割る」ことと「r! で割る」ことはそれぞれ「(n-r)! の逆元を掛ける」ことと「r! の逆元を掛ける」ことに等しいです。ですから階乗の逆元の値も前計算して用意しておきます。Python の pow 関数を使えばすぐにできます。
実装
入力例2を用いて解いたやり方を一般化し、どんなA, B, C, D についても解けるようにコードを書きます。以下のコードに付けたコメントを見ていただくとわかりますが、入力例2を解いた過程を見ながら式を作っていってます。こんなやり方ではなく、ちゃんと一から数式を作れるようになりたいですね……。
要は、Dが右から何個目のCの左にまで来るか?によって場合分けをしています。つまり計算量はCの大きさ次第で決まりますね。
# nCr % MOD を高速で求めるプログラム
import math
MOD = 998244353
N = 5 * 10 ** 6
# 階乗を求めておく (MOD)
power_list = [1, 1] # 0! = 1 と考える
for i in range(2, N+1):
power_list.append((power_list[-1] * i) % MOD)
#print(power_list)
# 階乗の逆元を求めておく (MOD)
r_power_list = [1]
for i in range(1, N+1):
r_power_list.append(pow(power_list[i], -1, MOD))
#print(r_power_list)
# nCr = n! / (r! * (n-r)!)
# MOD を考えるので、 r! で割るのは r! の逆元を掛けるのに等しくなる。よって
# nCr = n!(MOD) * 逆r!(MOD) * 逆(n-r)!(MOD) となる。
def ncr(n, r):
# nCr = math.factorial(n) / (math.factorial(r) * math.factorial(n - r))
nCr = (power_list[n] * r_power_list[r] * r_power_list[n - r]) % MOD
return nCr
A, B, C, D = map(int, input().split())
#A, B, C, D = 1, 2, 4, 8
ans = 0
# 左から A+1個をBエリア、右からC個をDエリア: 3C2, 11C7 => ncr(A+1-1+B , B) * ncr(C+D-1, D-1)
# 左から A+2個をBエリア、右からC-1個をDエリア: 4C2, 10C7 => ncr(A+2-1+B , B) * ncr(C-1+D-1, D-1)
# 左から A+3個をBエリア、右からC-2個をDエリア: 5C2, 9C7 =>
# 左から A+4個をBエリア、右からC-3個をDエリア: 6C2, 8C7 =>
# 左から A+5個をBエリア、右からC-4個をDエリア: 7C2, 7C7 =>
for i in range(C+1):
a = A+1+i
c = C-i
ans += ncr(a+B-1, B) * ncr(c+D-1, D-1)
ans %= MOD
print(ans)
"""
ans += ncr(11, 4) * ncr(3, 2)
ans += ncr(10, 3) * ncr(4, 2)
ans += (ncr(9, 2) * ncr(5, 2))
ans += (ncr(8, 1) * ncr(6, 2))
ans += (ncr(7, 0) * ncr(7, 2))
print(ans)
"""