問題
既存投稿一覧ページへのリンク
Before
AtCoder Beginner Contest 378_D 問題
Next
AtCoder Beginner Contest 379_C 問題
その他・記事まとめ
解法手順1
問題の概要
長さ $N$ の非負整数列 $A = (A_1, A_2, ..., A_N)$ と正整数 $M$ が与えられる。
全ての区間 $[l, r]$ ($1 \leq l \leq r \leq N$) について、 $\left(\sum_{i=l}^{r} A_i\right) mod M$の総和を求める問題。
解法の全体像
- 累積和を使い、区間和を効率的に計算する。
- 区間和の $M$ で割った余りを高速に数え上げるために、累積和のmod $M$ の値を管理する。
- 全ての区間について、区間和 mod $M$ の総和を効率的に計算する。
ステップ1: 配列Aの各要素をMで割った余りに変換
for i in range(N):
A[i] %= M
ステップ2: 累積和のmod Mを計算し、リストに格納
cur = A[0]
allsum = cur
mod_list = [cur]
for j in range(1, N):
cur = (cur + A[j]) % M
allsum += cur
mod_list.append(cur)
-
cur
は現在位置までの累積和のmod $M$。 -
allsum
は累積和mod $M$ の総和(全ての区間 [1, r] の和 mod $M$ の総和)。
ステップ3: mod_listの累積和を作成
区間の始点をずらしていく際に、累積和を効率よく計算するために、mod_list
の累積和リスト cummod
を作る。
cummod = [0]
for m in mod_list:
cummod.append(cummod[-1] + m)
ステップ4: 全区間の和の総和を初期値とする
まず、全ての区間が [1, r] となる場合(始点が1固定)の和(allsum
)を ans
の初期値とする。
ans = allsum
ステップ5: 始点を1からN-1までずらし、区間和を計算
始点を1つずつ右にずらし、全区間の和を計算する。
SortedList
を使って、累積和mod $M$ の値の出現数を高速に管理する。
from sortedcontainers import SortedList
mods = SortedList(mod_list)
for i in range(N-1):
cursum = allsum - cummod[i]
cursum -= mod_list[i] * (N - i)
mods.discard(mod_list[i])
under_ct = mods.bisect_left(mod_list[i])
cursum += under_ct * M
ans += cursum
-
cursum = allsum - cummod[i]
区間の始点を $i+1$ にしたときの累積和の総和。 -
cursum -= mod_list[i] * (N - i)
始点をずらすことで、mod_list[i] の値が (N-i) 回分減る。 -
mods.discard(mod_list[i])
今回の始点に対応する値をリストから削除。 -
under_ct = mods.bisect_left(mod_list[i])
mod_list[i] 未満の値がいくつあるか(二分探索で高速取得)。 -
cursum += under_ct * M
mod_list[i] 未満の値については、mod $M$ の性質上、Mを足すことで正しい値になる。 -
ans += cursum
各始点ごとに求めた区間和を加算。
ACコード1(※Loggerを消さないとTLE)
ac.py
import logging
from sortedcontainers import SortedList
def setup_logger(debug_mode):
logger = logging.getLogger(__name__)
if not logger.handlers:
logger.setLevel(logging.DEBUG if debug_mode else logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
file_handler = logging.FileHandler('program_trace.log', encoding="utf8")
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
logger.debug(f"ロガーのセットアップが完了しました。デバッグモード: {debug_mode}")
return logger
def io_func():
# 入力値取得
N_and_M = input().split()
N = int(N_and_M[0])
M = int(N_and_M[1])
A = list(map(int, input().split()))
return N, M, A
def solve(N, M, A, logger):
# 各要素をMで割った余りに変換
for i in range(N):
A[i] %= M
logger.debug(f"A[{i}]をMで割った余りに変換: {A[i]}")
cur = A[0]
allsum = cur
mods = SortedList()
mods.add(cur)
mod_list = [cur]
logger.debug(f"初期値 cur={cur}, allsum={allsum}, mods={list(mods)}, mod_list={mod_list}")
# 累積和のmod Mを計算
for j in range(1, N):
cur = (cur + A[j]) % M
allsum += cur
mods.add(cur)
mod_list.append(cur)
logger.debug(f"{j}番目: cur={cur}, allsum={allsum}, mods={list(mods)}, mod_list={mod_list}")
# mod_listの累積和を作成
cummod = [0]
for m in mod_list:
cummod.append(cummod[-1] + m)
logger.debug(f"mod_list累積和: {cummod}")
ans = allsum
logger.debug(f"初期ans={ans}")
for i in range(N-1):
cursum = allsum - cummod[i]
cursum -= mod_list[i] * (N - i)
mods.discard(mod_list[i])
under_ct = mods.bisect_left(mod_list[i])
cursum += under_ct * M
ans += cursum
logger.debug(
f"{i}番目: cursum={cursum}, allsum={allsum}, cummod[i]={cummod[i]}, "
f"mod_list[i]={mod_list[i]}, N-i={N-i}, under_ct={under_ct}, ans={ans}, mods={list(mods)}"
)
print(ans)
def main():
debug_mode = False # 必要に応じてTrueに
logger = setup_logger(debug_mode)
N, M, A = io_func()
solve(N, M, A, logger)
if __name__ == "__main__":
main()