緑色中くらい(rating:934)だけど、青色問題解けたよー!! の喜びを分かち合いたい&数週間後の自分に向けてのメモ。
TL;DR
遷移の回数がすごい多い( $ 10^{18} $とか)DPを解きたい場合、
- DPの遷移式を行列形式にする
- 1.で求めた行列を遷移回数分累乗してから、行列積求める
で解ける!
問題
公式解説
…こちとら緑の中くらいだぞ!! わかるか、ちくしょー!! となりました。
冒頭の
1文字ずつ文字を伸ばしていくことを考えると、末尾のたかだか L=6 文字を状態としてもつ DP を考えることができます。
だけ何とか理解。「"行列累乗"って何?」というレベル。ここからいろいろ調べてACできた。
解説
入力読み込み
ひとまず入力を読み込む
N,M = map(int, input().split())
S = set()
for _ in range(M):
s = str(input())
S.add(s)
6文字以下の場合を考えてみる
次に、ある文字列TがSに含まれる文字列sを部分文字列として含まない通り数を考える。
sは制約により長さ6以下なので、ひとまず長さ6文字での全通りのパターンに対して考えてみる。
sを含んでいるか否かは.count(s)が1以上なら含んでる、0なら含んでないと調べられる。
Tに含まれるのはa,bのみで全パターンを調べる必要があるので、ここはbit全探索を使う。
まずはbit全探索しやすいように入力を修正。
N,M = map(int, input().split())
S = set()
for _ in range(M):
s = str(input())
s = s.replace('a', '0').replace('b', '1') # ←追加行。扱いやすいようaを0、bを1に置換する。
S.add(s)
そして、bit全探索する。
pattern = dict() # keyの文字列を取れる通り数
for i in range(2 ** min(6,N)):
str_i = str(bin(i))[2:] # bin(i)で2進数にして、strで文字列形式にして、[2:]で頭の"0b"を削る
str_i = str_i.zfill(min(6,N)) # iを6桁(or N桁のうち少ない方の桁数)の01の組み合わせにしたもの
for s in S:
if str_i.count(s) != 0:
pattern[str_i] = 0
break
else:
pattern[str_i] = 1
$N <= 6$の場合は、各文字列での通り数を合計すれば良いので、こうなる。
if N <= 6:
ans = 0
for k,v in pattern.items():
ans += v
print(ans)
exit()
これで入力例1はクリアー♪
7文字以上の場合を考えてみる
7文字で末尾がaaaaaaとなり、かつsが含まれない通り数を考えてみる。
6文字以下の場合で考えた結果を利用すると、
- baaaaaのときにaが入る場合(b aaaaaa)
- aaaaaaのときにaが入る場合(a aaaaaa)
の和となる。但し、もちろんaaaaaaがSに含まれていたら0通り。
同様に7文字で末尾がaaaaabとなる通り数は、末尾がaaaaaaの通り数 + 末尾がbaaaaaの通り数。
同様に7文字で末尾がaaaabaとなる通り数は、末尾がaaaaabの通り数 + 末尾がbaaaabの通り数。
(略)
同様に7文字で末尾がbbbbbaとなる通り数は、末尾がabbbbbの通り数 + 末尾がbbbbbbの通り数。
同様に7文字で末尾がbbbbbbとなる通り数は、末尾がabbbbbの通り数 + 末尾がbbbbbbの通り数。
(※いずれも末尾6文字がSに含まれていたら0通り)
となる。
前の結果を利用して、次の結果を導き出すといえばDPですね!
ということで、末尾にstrを取るi文字の通り数をdp(str,i)と表すことにします。
上の分をdp(str,i)の形式で書き直すと、
dp("aaaaaa",7) = dp("aaaaaa",6) + dp("baaaaa",6)
dp("aaaaab",7) = dp("aaaaaa",6) + dp("baaaaa",6)
dp("aaaaba",7) = dp("aaaaab",6) + dp("baaaab",6)
(略)
dp("bbbbba",7) = dp("abbbbb",6) + dp("bbbbbb",6)
dp("bbbbbb",7) = dp("abbbbb",6) + dp("bbbbbb",6)
(※いずれも末尾6文字がSに含まれていたら0通り)
となる。
式を眺めてみるといずれも dp(Z,i) = dp(X,i-1) + dp(Y,i-1) ですね。
変形すると dp(Z,i) = 1*dp(X,i-1) + 1*dp(Y,i-1)
さらに変形すると dp(Z,i) = 0*dp(U,i-1) + … + 1*dp(X,i-1) + 0*dp(V,i-1) + … 1*dp(Y,i-1) + 0*dp(W,i-1) + …
となります。
この形と言えば、そう行列です!
ということで行列の形式で書き直すと、
$DP[i] = A・DP[i-1]$
$DP[i]$: dp("aaaaaa",i)~dp("bbbbbb",i)が行方向に並んでいるもの
$A$ : 1か0が要素の64*64の行列
です。
$DP[i+1] = A・DP[i] = A・A・DP[i-1]$ となるので、
以後も同様に考えると、$DP[N] = A^{N-6}・DP[6]$ として求めることが出来そうです!
ということで、まずは$A$を求めます!
# Aを求める
A = [[0 for _ in range(2**6)] for _ in range(2**6)]
for i in range(2**6):
str_i = str(bin(i))[2:] # bin(i)で2進数にして、strで文字列形式にして、[2:]で頭の"0b"と入ってしまうのを削る
str_i = str_i.zfill(min(6,N)) # iを6桁の01の組み合わせにしたもの
if pattern[str_i] == 1:
A[i][i//2] = 1
A[i][2**5 + i//2] = 1
else:
# 末尾6文字がsなので、0のままにする
pass
次に $A^{N-6}$ と、その結果と $DP[6]$ の行列積を求めます。
が、$N$が大きいので単純に$N$回掛け合わせるとTLEしそうです…
ということで$N$の乗数を$2^n$の掛け合わせで考える(例:$A^{11} = A^8 * A^2 * A^1$)ことで高速化します!
まずは、計算に必要な行列の累乗や積を求める関数を作ります!
MOD_BY = 998244353
def get_matrix_product(A,B):
'''
行列Aと行列Bの積を求める
# cf : https://w3e.kanazawa-it.ac.jp/math/category/gyouretu/senkeidaisu/henkan-tex.cgi?target=/math/category/gyouretu/senkeidaisu/gyouretu-no-seki.html
'''
l = len(A) # 行列Aの行数
m = len(A[0]) # 行列Aの列数 (= 行列Bの行数)
n = len(B[0]) # 行列Bの列数
if m != len(B):
# 行列Aの列数と行列Bの行数が不一致のため、行列積を計算できない
return -1
else:
C = [[0 for _ in range(n)] for _ in range(l)]
for i in range(l):
for j in range(n):
for k in range(m):
C[i][j] += A[i][k] * B[k][j]
C[i][j] %= MOD_BY
return C
def get_power_matrix(A,n):
'''
A**nの行列を求める
'''
if len(A) != len(A[0]):
# 正方行列でなければ、A**Nは計算不可
return 0
size = len(A) # Aの行数 (= Aの列数)
if n & (1<<0):
# A ** N は A**1を含むので、Aで仮置き
ans = [[A[i][j] for j in range(size)] for i in range(size)]
else:
# A ** N は A**1を含まないので、単位行列で仮置き
ans = [[0 for _ in range(size)] for _ in range(size)]
for i in range(size):
ans[i][i] = 1
# A**2,A**4,A**8 ... を求めつつ,nを表すのに必要だったら掛け合わせる
powered_A = A
for i in range(1,64):
powered_A = get_matrix_product(powered_A,powered_A) # A ** (i+1)
if n & (1<<i):
# A ** N は A**(i+1)を含むので、ansに掛け合わせる
ans = get_matrix_product(powered_A,ans)
return ans
準備が出来たので、$DP[N] = A^{N-6}・DP[6]$ を求めます!
# 行列累乗する
A_N = get_power_matrix(A,N-6) # A ** (N-6)。上の部分で長さNのうち、6文字は消費しているので削る
dp_6 = [[pattern[str(bin(i))[2:].zfill(6)]] for i in range(2**6)] # 文字列長さ6までのdpの結果(dp[6])。A **(N-6) と掛けるために列数1、行数2**6の形。
dp_N = get_matrix_product(A_N,dp_6) # A**(N-6) ・ f(str,6) して dp(str,N)を求める
答えは長さNの文字列の通り数を998244353で割ったものなので、求めて出力します。
# dp[N]の結果を足し合わせて、答えを求めて出力する
ans = 0
for i in range(2**6):
ans += dp_N[i][0]
ans %= MOD_BY
print(ans)
全体を繋げたコードは下記
N,M = map(int, input().split())
S = set()
for _ in range(M):
s = str(input())
s = s.replace('a', '0').replace('b', '1') # ←追加行。扱いやすいよう置換
S.add(s)
pattern = dict() # keyの文字列を取れる通り数
for i in range(2 ** min(6,N)):
str_i = str(bin(i))[2:] # bin(i)で2進数にして、strで文字列形式にして、[2:]で頭の"0b"と入ってしまうのを削る
str_i = str_i.zfill(min(6,N)) # iを6桁の01の組み合わせにしたもの
for s in S:
if str_i.count(s) != 0:
pattern[str_i] = 0
break
else:
pattern[str_i] = 1
if N <= 6:
ans = 0
for k,v in pattern.items():
ans += v
print(ans)
exit()
# Aを求める
A = [[0 for _ in range(2**6)] for _ in range(2**6)]
for i in range(2**6):
str_i = str(bin(i))[2:] # bin(i)で2進数にして、strで文字列形式にして、[2:]で頭の"0b"と入ってしまうのを削る
str_i = str_i.zfill(min(6,N)) # iを6桁(or N桁のうち少ない方の桁数)の01の組み合わせにしたもの
if pattern[str_i] == 1:
A[i][i//2] = 1
A[i][2**5 + i//2] = 1
else:
# 末尾6文字がsなので、0のままにする
pass
MOD_BY = 998244353
def get_matrix_product(A,B):
'''
行列Aと行列Bの積を求める
# cf : https://w3e.kanazawa-it.ac.jp/math/category/gyouretu/senkeidaisu/henkan-tex.cgi?target=/math/category/gyouretu/senkeidaisu/gyouretu-no-seki.html
'''
l = len(A) # 行列Aの行数
m = len(A[0]) # 行列Aの列数 (= 行列Bの行数)
n = len(B[0]) # 行列Bの列数
if m != len(B):
# 行列Aの列数と行列Bの行数が不一致のため、行列積を計算できない
return -1
else:
C = [[0 for _ in range(n)] for _ in range(l)]
for i in range(l):
for j in range(n):
for k in range(m):
C[i][j] += A[i][k] * B[k][j]
C[i][j] %= MOD_BY
return C
def get_power_matrix(A,n):
'''
A**nの行列を求める
'''
if len(A) != len(A[0]):
# 正方行列でなければ、A**Nは計算不可
return 0
size = len(A) # Aの行数 (= Aの列数)
if n & (1<<0):
# A ** N は A**1を含むので、Aで仮置き
ans = [[A[i][j] for j in range(size)] for i in range(size)]
else:
# A ** N は A**1を含まないので、単位行列で仮置き
ans = [[0 for _ in range(size)] for _ in range(size)]
for i in range(size):
ans[i][i] = 1
# A**2,A**4,A**8 ... を求めつつ,nを表すのに必要だったら掛け合わせる
powered_A = A
for i in range(1,64):
powered_A = get_matrix_product(powered_A,powered_A) # A ** (i+1)
if n & (1<<i):
# A ** N は A**(i+1)を含むので、ansに掛け合わせる
ans = get_matrix_product(powered_A,ans)
return ans
# 行列累乗する
A_N = get_power_matrix(A,N-6) # A ** (N-6)。上の部分で長さNのうち、6文字は消費しているので削る
dp_6 = [[pattern[str(bin(i))[2:].zfill(6)]] for i in range(2**6)] # 文字列長さ6までのdpの結果(dp[6])。A **(N-6) と掛けるために列数1、行数2**6の形。
dp_N = get_matrix_product(A_N,dp_6) # A**(N-6) ・ f(str,6) して dp(str,N)を求める
# dp[N]の結果を足し合わせて、答えを求めて出力する
ans = 0
for i in range(2**6):
ans += dp_N[i][0]
ans %= MOD_BY
print(ans)