ABC398 F - ABCBA
https://atcoder.jp/contests/abc398/tasks/abc398_f
先日のF問題が解けなくて悔しかったので勉強してACしてきました。
コンテスト中&コンテスト後の自分の動き
Dまで解いてEを捨て、35分ほど残った時間でFに挑みました。文字列Sの後ろにある回文を求めればよいことはすぐわかりましたが、回文の判定には時間がかかるので普通にやっても解けません。あれこれ悩みましたが結局どうしようもありませんでした。
コンテスト後の皆さんのツイートからなんとか「ローリングハッシュ」なる単語を拾ってきました。よく聞く名前なのでこの機会に覚えるべきだと思い、勉強しました。
https://qiita.com/hirominn/items/80464ee381c8d400725f
この考え方を利用してなんとかACすることができました。
やってみると案外簡単で、どうして今まで勉強してこなかったのかと後悔したぐらいです。ABC の E 問題以降を解けるようにするためにはこれも技としてしっかり身につけて、必要な場面で使えるようにしておかないといけないと思いました。
考察
S を逆に並び替えた文字列を T とします。
- S = [非回文1] + [回文]
- T = [回文] + [非回文1の逆]
という形になり、答えは
- [非回文1] + [回文] + [非回文1の逆]
となります。これを求めるために、S を後ろからみたときに回文となる文字列のうち最も長いものを探します。公式解説に書かれている通りです。
i = 1, 2, ..., N について、「T の頭から i 文字目まで」と「S の後ろから i 文字目まで」を比較していきます。普通に比較すると計算量オーバーになりますから、ここでハッシュ値を利用します。i が増えるにつれて T は後ろへ後ろへと伸び、S は前へ前へと伸びていきます。「i 文字目までの文字列のハッシュ値」から「i+1 文字目までの文字列のハッシュ値」は後述するように素早く求めることができます。
T[:k] のハッシュ値から T[:k+1] のハッシュ値を求める
\begin{align}
hash(k) &= T_1 \cdot b^{k-1} &+ T_2 \cdot b^{k-2} &+ ...... &+ T_k \cdot b^0\\
hash(k+1) &= T_1 \cdot b^{k} &+ T_2 \cdot b^{k-1} &+ ...... &+ T_k \cdot b^1 &+ T_{k+1} \cdot b^0
\end{align}
より
\begin{align}
hash(k+1) = hash(k) \cdot b + T_{k+1}
\end{align}
です。先に引用したページ
https://qiita.com/hirominn/items/80464ee381c8d400725f
では長さ M の文字列を 1 つずつずらしながら見ていましたが、今回はそれよりも単純です。
S[-k:] のハッシュ値から S[-(k+1):] のハッシュ値を求める
\begin{align}
hash(k) &= &S_k \cdot b^{k-1} &+ S_{k-1} \cdot b^{k-2} &+ ...... &+ S_1 \cdot b^0\\
hash(k+1) &= S_{k+1} \cdot b^{k} &+ S_{k} \cdot b^{k-1} &+ S_{k-1} \cdot b^{k-2} &+ ...... &+ S_1 \cdot b^0
\end{align}
より
\begin{align}
hash(k+1) = hash(k) + S_{k+1} \cdot b^{k}
\end{align}
です。こちらはさらに簡単ですね。
実装
前から i 文字のハッシュ値を求める部分と後ろから i 文字のハッシュ値を求める部分をそれぞれ関数にまとめました。
また、ハッシュ値の衝突が起こりにくいように念のため基数を変えて 2 回計算しています。
S = input()
T = S[::-1]
N = len(S)
S, T = "." + S, "." + T # 1-index にする
MOD = 998244353
BASE1 = 100019
BASE2 = 15
# 前から i 文字目までの部分文字列のハッシュ値を hash[i] に格納して返す
def hash_from_head(str, base, mod):
n = len(str) - 1 # str は 1-indexなので
hash = [0 for _ in range(n+1)]
for i in range(1, n+1):
hash[i] = hash[i-1] * base + ord(str[i])
hash[i] %= mod
return hash
# 後ろから i 文字目までの部分文字列のハッシュ値を hash[i] に格納して返す
def hash_from_tail(str, base, mod):
n = len(str) - 1 # str は 1-indexなので
hash = [0 for _ in range(n+1)]
for i in range(1, N+1):
hash[i] = hash[i-1] + ord(str[-i]) * pow(base, i-1, mod)
hash[i] %= mod
return hash
# T[:i] は後ろへ後ろへと伸びていく。ローリングハッシュの要領で次のハッシュ値を求める。
# 念のため2種類のハッシュ値で調べる。
hashs_head_T1 = hash_from_head(T, BASE1, MOD)
hashs_head_T2 = hash_from_head(T, BASE2, MOD)
# S[i:] は前へ前へと伸びていく。ローリングハッシュの要領で次のハッシュ値を求める。
hashs_tail_S1 = hash_from_tail(S, BASE1, MOD)
hashs_tail_S2 = hash_from_tail(S, BASE2, MOD)
# 最長の回文を探す
palindrome_length = 1
for i in range(1, N+1):
if hashs_head_T1[i] == hashs_tail_S1[i] and hashs_head_T2[i] == hashs_tail_S2[i]:
palindrome_length = i
ans = S[1:] + T[palindrome_length + 1:]
print(ans)