2
0

More than 1 year has passed since last update.

お題は不問!Qiita Engineer Festa 2023で記事投稿!

ABC305G-「Banned Substrings」(行列累乗)、完全に理解した!!

Last updated at Posted at 2023-06-21

緑色中くらい(rating:934)だけど、青色問題解けたよー!! の喜びを分かち合いたい&数週間後の自分に向けてのメモ。

TL;DR

遷移の回数がすごい多い( $ 10^{18} $とか)DPを解きたい場合、

  1. DPの遷移式を行列形式にする
  2. 1.で求めた行列を遷移回数分累乗してから、行列積求める

で解ける!

問題

公式解説

…こちとら緑の中くらいだぞ!! わかるか、ちくしょー!! となりました。
冒頭の

1文字ずつ文字を伸ばしていくことを考えると、末尾のたかだか L=6 文字を状態としてもつ DP を考えることができます。

だけ何とか理解。「"行列累乗"って何?」というレベル。ここからいろいろ調べてACできた。

解説

入力読み込み

ひとまず入力を読み込む

入力読み込み
N,M = map(int, input().split())
S = set()
for _ in range(M):
    s = str(input())
    S.add(s)

6文字以下の場合を考えてみる

次に、ある文字列TがSに含まれる文字列sを部分文字列として含まない通り数を考える。
sは制約により長さ6以下なので、ひとまず長さ6文字での全通りのパターンに対して考えてみる。
sを含んでいるか否かは.count(s)が1以上なら含んでる、0なら含んでないと調べられる。
Tに含まれるのはa,bのみで全パターンを調べる必要があるので、ここはbit全探索を使う。

まずはbit全探索しやすいように入力を修正。

入力読み込み修正
N,M = map(int, input().split())
S = set()
for _ in range(M):
    s = str(input())
    s = s.replace('a', '0').replace('b', '1')   # ←追加行。扱いやすいようaを0、bを1に置換する。
    S.add(s)

そして、bit全探索する。

bit全探索
pattern = dict()                    # keyの文字列を取れる通り数

for i in range(2 ** min(6,N)):
    str_i = str(bin(i))[2:]         # bin(i)で2進数にして、strで文字列形式にして、[2:]で頭の"0b"を削る
    str_i = str_i.zfill(min(6,N))   # iを6桁(or N桁のうち少ない方の桁数)の01の組み合わせにしたもの

    for s in S:
        if str_i.count(s) != 0:
            pattern[str_i] = 0
            break
    else:
        pattern[str_i] = 1

$N <= 6$の場合は、各文字列での通り数を合計すれば良いので、こうなる。

N <= 6の場合
if N <= 6:
    ans = 0
    for k,v in pattern.items():
        ans += v
    print(ans)
    exit()

これで入力例1はクリアー♪

7文字以上の場合を考えてみる

7文字で末尾がaaaaaaとなり、かつsが含まれない通り数を考えてみる。
6文字以下の場合で考えた結果を利用すると、

  • baaaaaのときにaが入る場合(b aaaaaa)
  • aaaaaaのときにaが入る場合(a aaaaaa)

の和となる。但し、もちろんaaaaaaがSに含まれていたら0通り。
同様に7文字で末尾がaaaaabとなる通り数は、末尾がaaaaaaの通り数 + 末尾がbaaaaaの通り数。
同様に7文字で末尾がaaaabaとなる通り数は、末尾がaaaaabの通り数 + 末尾がbaaaabの通り数。
 (略)
同様に7文字で末尾がbbbbbaとなる通り数は、末尾がabbbbbの通り数 + 末尾がbbbbbbの通り数。
同様に7文字で末尾がbbbbbbとなる通り数は、末尾がabbbbbの通り数 + 末尾がbbbbbbの通り数。
(※いずれも末尾6文字がSに含まれていたら0通り)
となる。

前の結果を利用して、次の結果を導き出すといえばDPですね!
ということで、末尾にstrを取るi文字の通り数をdp(str,i)と表すことにします。
上の分をdp(str,i)の形式で書き直すと、
 dp("aaaaaa",7) = dp("aaaaaa",6) + dp("baaaaa",6)
 dp("aaaaab",7) = dp("aaaaaa",6) + dp("baaaaa",6)
 dp("aaaaba",7) = dp("aaaaab",6) + dp("baaaab",6)
  (略)
 dp("bbbbba",7) = dp("abbbbb",6) + dp("bbbbbb",6)
 dp("bbbbbb",7) = dp("abbbbb",6) + dp("bbbbbb",6)
(※いずれも末尾6文字がSに含まれていたら0通り)
となる。

式を眺めてみるといずれも dp(Z,i) = dp(X,i-1) + dp(Y,i-1) ですね。
変形すると    dp(Z,i) = 1*dp(X,i-1) + 1*dp(Y,i-1)
さらに変形すると dp(Z,i) = 0*dp(U,i-1) + … + 1*dp(X,i-1) + 0*dp(V,i-1) + … 1*dp(Y,i-1) + 0*dp(W,i-1) + …
となります。

この形と言えば、そう行列です!
ということで行列の形式で書き直すと、
 $DP[i] = A・DP[i-1]$
   $DP[i]$: dp("aaaaaa",i)~dp("bbbbbb",i)が行方向に並んでいるもの
   $A$ : 1か0が要素の64*64の行列
です。

$DP[i+1] = A・DP[i] = A・A・DP[i-1]$ となるので、
以後も同様に考えると、$DP[N] = A^{N-6}・DP[6]$ として求めることが出来そうです!
ということで、まずは$A$を求めます!

Aを求める
# Aを求める
A = [[0 for _ in range(2**6)] for _ in range(2**6)]
for i in range(2**6):
    str_i = str(bin(i))[2:]         # bin(i)で2進数にして、strで文字列形式にして、[2:]で頭の"0b"と入ってしまうのを削る
    str_i = str_i.zfill(min(6,N))   # iを6桁の01の組み合わせにしたもの

    if pattern[str_i] == 1:
        A[i][i//2] = 1
        A[i][2**5 + i//2] = 1
    else:
        # 末尾6文字がsなので、0のままにする
        pass

次に $A^{N-6}$ と、その結果と $DP[6]$ の行列積を求めます。
が、$N$が大きいので単純に$N$回掛け合わせるとTLEしそうです…
ということで$N$の乗数を$2^n$の掛け合わせで考える(例:$A^{11} = A^8 * A^2 * A^1$)ことで高速化します!
まずは、計算に必要な行列の累乗や積を求める関数を作ります!

行列の累乗や積を求める
MOD_BY = 998244353

def get_matrix_product(A,B):
    '''
    行列Aと行列Bの積を求める
    # cf : https://w3e.kanazawa-it.ac.jp/math/category/gyouretu/senkeidaisu/henkan-tex.cgi?target=/math/category/gyouretu/senkeidaisu/gyouretu-no-seki.html
    '''
    l = len(A)        # 行列Aの行数
    m = len(A[0])     # 行列Aの列数 (= 行列Bの行数)
    n = len(B[0])     # 行列Bの列数

    if m != len(B):
        # 行列Aの列数と行列Bの行数が不一致のため、行列積を計算できない
        return -1
    else:
        C = [[0 for _ in range(n)] for _ in range(l)]

        for i in range(l):
            for j in range(n):
                for k in range(m):
                    C[i][j] += A[i][k] * B[k][j]
                    C[i][j] %= MOD_BY

        return C

def get_power_matrix(A,n):
    '''
    A**nの行列を求める
    '''
    if len(A) != len(A[0]):
        # 正方行列でなければ、A**Nは計算不可
        return 0

    size = len(A)   # Aの行数 (= Aの列数)

    if n & (1<<0):
        # A ** N は A**1を含むので、Aで仮置き
        ans = [[A[i][j] for j in range(size)] for i in range(size)]
    else:
        # A ** N は A**1を含まないので、単位行列で仮置き
        ans = [[0 for _ in range(size)] for _ in range(size)]
        for i in range(size):
            ans[i][i] = 1

    # A**2,A**4,A**8 ... を求めつつ,nを表すのに必要だったら掛け合わせる
    powered_A = A

    for i in range(1,64):
        powered_A = get_matrix_product(powered_A,powered_A)     # A ** (i+1)

        if n & (1<<i):
            # A ** N は A**(i+1)を含むので、ansに掛け合わせる
            ans = get_matrix_product(powered_A,ans)

    return ans

準備が出来たので、$DP[N] = A^{N-6}・DP[6]$ を求めます!

行列累乗
# 行列累乗する
A_N = get_power_matrix(A,N-6)                                       # A ** (N-6)。上の部分で長さNのうち、6文字は消費しているので削る
dp_6 = [[pattern[str(bin(i))[2:].zfill(6)]] for i in range(2**6)]   # 文字列長さ6までのdpの結果(dp[6])。A **(N-6) と掛けるために列数1、行数2**6の形。
dp_N = get_matrix_product(A_N,dp_6)                                 # A**(N-6) ・ f(str,6) して dp(str,N)を求める

答えは長さNの文字列の通り数を998244353で割ったものなので、求めて出力します。

答えを求めて出力
# dp[N]の結果を足し合わせて、答えを求めて出力する
ans = 0
for i in range(2**6):
    ans += dp_N[i][0]
    ans %= MOD_BY

print(ans)

全体を繋げたコードは下記

コード全体
N,M = map(int, input().split())
S = set()

for _ in range(M):
    s = str(input())
    s = s.replace('a', '0').replace('b', '1')   # ←追加行。扱いやすいよう置換
    S.add(s)

pattern = dict()                    # keyの文字列を取れる通り数

for i in range(2 ** min(6,N)):
    str_i = str(bin(i))[2:]         # bin(i)で2進数にして、strで文字列形式にして、[2:]で頭の"0b"と入ってしまうのを削る
    str_i = str_i.zfill(min(6,N))   # iを6桁の01の組み合わせにしたもの

    for s in S:
        if str_i.count(s) != 0:
            pattern[str_i] = 0
            break
    else:
        pattern[str_i] = 1

if N <= 6:
    ans = 0
    for k,v in pattern.items():
        ans += v
    print(ans)
    exit()

# Aを求める
A = [[0 for _ in range(2**6)] for _ in range(2**6)]
for i in range(2**6):
    str_i = str(bin(i))[2:]         # bin(i)で2進数にして、strで文字列形式にして、[2:]で頭の"0b"と入ってしまうのを削る
    str_i = str_i.zfill(min(6,N))   # iを6桁(or N桁のうち少ない方の桁数)の01の組み合わせにしたもの

    if pattern[str_i] == 1:
        A[i][i//2] = 1
        A[i][2**5 + i//2] = 1
    else:
        # 末尾6文字がsなので、0のままにする
        pass

MOD_BY = 998244353

def get_matrix_product(A,B):
    '''
    行列Aと行列Bの積を求める
    # cf : https://w3e.kanazawa-it.ac.jp/math/category/gyouretu/senkeidaisu/henkan-tex.cgi?target=/math/category/gyouretu/senkeidaisu/gyouretu-no-seki.html
    '''
    l = len(A)        # 行列Aの行数
    m = len(A[0])     # 行列Aの列数 (= 行列Bの行数)
    n = len(B[0])     # 行列Bの列数

    if m != len(B):
        # 行列Aの列数と行列Bの行数が不一致のため、行列積を計算できない
        return -1
    else:
        C = [[0 for _ in range(n)] for _ in range(l)]

        for i in range(l):
            for j in range(n):
                for k in range(m):
                    C[i][j] += A[i][k] * B[k][j]
                    C[i][j] %= MOD_BY

        return C

def get_power_matrix(A,n):
    '''
    A**nの行列を求める
    '''
    if len(A) != len(A[0]):
        # 正方行列でなければ、A**Nは計算不可
        return 0

    size = len(A)   # Aの行数 (= Aの列数)

    if n & (1<<0):
        # A ** N は A**1を含むので、Aで仮置き
        ans = [[A[i][j] for j in range(size)] for i in range(size)]
    else:
        # A ** N は A**1を含まないので、単位行列で仮置き
        ans = [[0 for _ in range(size)] for _ in range(size)]
        for i in range(size):
            ans[i][i] = 1

    # A**2,A**4,A**8 ... を求めつつ,nを表すのに必要だったら掛け合わせる
    powered_A = A

    for i in range(1,64):
        powered_A = get_matrix_product(powered_A,powered_A)     # A ** (i+1)

        if n & (1<<i):
            # A ** N は A**(i+1)を含むので、ansに掛け合わせる
            ans = get_matrix_product(powered_A,ans)

    return ans

# 行列累乗する
A_N = get_power_matrix(A,N-6)                                       # A ** (N-6)。上の部分で長さNのうち、6文字は消費しているので削る
dp_6 = [[pattern[str(bin(i))[2:].zfill(6)]] for i in range(2**6)]   # 文字列長さ6までのdpの結果(dp[6])。A **(N-6) と掛けるために列数1、行数2**6の形。
dp_N = get_matrix_product(A_N,dp_6)                                 # A**(N-6) ・ f(str,6) して dp(str,N)を求める
# dp[N]の結果を足し合わせて、答えを求めて出力する
ans = 0
for i in range(2**6):
    ans += dp_N[i][0]
    ans %= MOD_BY

print(ans)

ということで上記コードを提出すると
image.png
見事ACです!! rating934だけど青色問題解けた~!!

2
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
0