20
8

More than 1 year has passed since last update.

【AtCoder解説】PythonでABC253のA,B,C,D,E問題を制する!

Posted at

ABC253A,B,C,D,E問題を、Python3でなるべく丁寧に解説していきます。

ただ解けるだけの方法ではなく、次の3つのポイントを満たす解法を解説することを目指しています。

  • シンプル:余計なことを考えずに済む
  • 実装が楽:ミスやバグが減ってうれしい
  • 時間がかからない:パフォが上がって、後の問題に残せる時間が増える

ご質問・ご指摘はコメントツイッターマシュマロ、Discordサーバーまでお気軽にどうぞ!

Twitter: u2dayo
マシュマロ: https://marshmallow-qa.com/u2dayo
ほしいものリスト: https://www.amazon.jp/hz/wishlist/ls/2T9IQ8IK9ID19?ref_=wl_share
Discordサーバー(質問や記事の感想・リクエストなどどうぞ!) : https://discord.gg/jZ8pkPRRMT
よかったらLGTM拡散していただけると喜びます!

目次

ABC253 まとめ
A問題『Median?』
B問題『Distance Between Tokens』
C問題『Max - Min Query』
D問題『FizzBuzz Sum Hard』
E問題『Distance Sequence』

アプリ AtCoderFacts を開発しています

コンテストの統計データを見られるアプリ『AtCoderFacts』を作りました。
現在のところ、次の3つのデータを見ることができます。

  • レート別問題正解率
  • パフォーマンス目安
  • 早解きで上昇するパフォーマンス

今後も機能を追加していく予定です。使ってくれると喜びます。

ABC253 まとめ

全提出人数: 8944人

パフォーマンス

パフォ AC 点数 時間 順位(Rated内)
200 AB------ 300 42分 6312(6063)位
400 AB------ 300 12分 5172(4923)位
600 AB-D---- 700 109分 4180(3941)位
800 ABCD---- 1000 88分 3221(2984)位
1000 ABCD---- 1000 35分 2349(2112)位
1200 ABCDE--- 1500 99分 1665(1438)位
1400 ABCDE--- 1500 65分 1170(946)位
1600 ABCDE--- 1500 38分 806(590)位
1800 ABCDEF-- 2000 94分 540(343)位
2000 ABCDEF-- 2000 52分 345(184)位
2200 ABCDEFG- 2600 105分 214(91)位
2400 ABCDEFG- 2600 85分 138(43)位

色別の正解率

人数 A B C D E F G Ex
3319 94.5 % 79.2 % 23.1 % 22.0 % 2.5 % 0.1 % 0.0 % 0.0 %
1507 98.9 % 97.7 % 68.2 % 66.2 % 12.5 % 0.2 % 0.6 % 0.1 %
1221 97.9 % 97.4 % 85.9 % 89.0 % 48.6 % 3.8 % 1.1 % 0.1 %
663 98.0 % 97.6 % 93.5 % 95.6 % 88.4 % 21.0 % 4.1 % 0.0 %
412 98.3 % 97.6 % 96.8 % 97.3 % 96.1 % 64.6 % 30.6 % 1.7 %
183 91.3 % 90.7 % 90.7 % 90.2 % 89.1 % 73.2 % 59.0 % 9.3 %
40 90.0 % 90.0 % 87.5 % 92.5 % 85.0 % 80.0 % 87.5 % 30.0 %
29 100.0 % 96.5 % 100.0 % 96.5 % 100.0 % 96.5 % 100.0 % 62.1 %

表示レート、灰に初参加者は含めず

A問題『Median?』

問題ページA - Median?
コーダー正解率:94.5 %
コーダー正解率:98.9 %
コーダー正解率:97.9 %

入力

$a,b,c$ : 整数

考察

中央値と $b$ が同じか判定したいので、$a,b,c$ の中央値を求められればいいです。

$3$ つある整数の中央値を求める方法は色々あります。今回は次の三つの方法を紹介します。

  • $a,b,c$ をソートして、$3$ つのうち $2$ 番目が中央値
  • $a+b+c$ から最大値と最小値を引く
  • statisticsモジュールのmedian関数を使う

ソートする方法

中央値とは、小さい順または大きい順に数を並べたとき、ちょうど真ん中にある数のことです。(※)

$a,b,c$ を小さい順にソートして、$2$ 番目に小さい値を取ればそれが中央値です。

※問題によって要素数が偶数のときの中央値の定義が異なるので、定義が書いてある場合はよく確認しましょう

最大値と最小値を引く方法

$3$ つの整数の最大値を $p_{max}$、最小値を $p_{min}$、中央値を $p_{med}$ とおきます。求めたいものは中央値の $p_{med}$ です。

$a+b+c=p_{min}+p_{med}+p_{max}$ です。つまり $a+b+c-p_{min}-p_{max}=p_{med}$ です。

最大値と最小値はmax関数とmin関数で簡単に求められます。

median関数を使う

statisticsモジュールの中央値を求めるmedian関数を使うだけです。

なおmedian関数は要素数 $N$ が偶数のとき、$\dfrac{N}{2}$ 番目と、$\dfrac{N}{2}+1$ 番目の平均を返します。$\dfrac{N}{2}$ 番目を返すmedian_low関数、$\dfrac{N}{2}+1$ 番目を返す median_high関数もあります。

コード

ソートして2番目

a, b, c = map(int, input().split())
m = sorted((a, b, c))[1]  # 3個あるうちの2番目が中央値なので、[1]です
print("Yes" if b == m else "No")

最大値と最小値を引く

a, b, c = map(int, input().split())
m = a + b + c - min(a, b, c) - max(a, b, c)
print("Yes" if b == m else "No")

statisticsモジュールのmedian関数

from statistics import median

a, b, c = map(int, input().split())
m = median((a, b, c))
print("Yes" if b == m else "No")

B問題『Distance Between Tokens』

問題ページB - Distance Between Tokens
コーダー正解率:79.2 %
コーダー正解率:97.7 %
コーダー正解率:97.4 %

入力

$H,W$ : $H$ 行 $W$ 列のマス目がある
$S_i$ : $i$ 行目のマス目の状態を表す、長さ $W$ の文字列

考察

$2$ つの o が左右に何マス分離れていて、上下に何マス分離れているか求め、それらを足すと答えになります。

$2$ つのoの位置は、二重ループで行と列を全探索して探せばわかります。

実装

見つけた点をリストに追加しておいて、最後に行の差の絶対値、列の差の絶対値を求めるのもいいです。

もうひとつの方法に、二つの変数dr, dcを $0$ で初期化し、点が見つかったら

dr = abs(dr-row)
dc = abs(dc-col)

とします。$1$ 個目のoでは、dr,dcは $0$ なので、$1$ 個目の点の座標にそのまま更新されます。$2$ 個目の o で、dr, dcは求めたい $2$ 点の差となります。あとは最後にans = dr + dc とすればいいです。タイプ量が少なくて楽なので、こちらの方法もおすすめです。

コード

見つけた点をリストに追加する方法

H, W = map(int, input().split())
S = [input() for _ in range(H)]
P = []
for row in range(H):
    for col in range(W):
        if S[row][col] == "o":
            P.append((row, col))
ar, ac = P[0]
br, bc = P[1]
ans = abs(ar - br) + abs(ac - bc)
print(ans)

点を見つけたら絶対値の差を更新する方法

H, W = map(int, input().split())
S = [input() for _ in range(H)]
ar, ac = 0, 0
for row in range(H):
    for col in range(W):
        if S[row][col] == "o":
            ar = abs(ar - row)
            ac = abs(ac - col)
print(ar + ac)

C問題『Max - Min Query』

問題ページC - Max - Min Query
コーダー正解率:23.1 %
コーダー正解率:68.2 %
コーダー正解率:85.9 %

入力

$Q$ : クエリの個数

$query_i$ は以下の $3$ 種類のいずれか

$1\ \ x$ : $S$ に $x$ を $1$ 個追加する。
$2\ \ x\ \ c$ : $S$ から $x$ を $\mathrm{min}(c,Sに含まれるxの個数)$ 個削除する。つまり、基本的には $S$ から $x$ を $c$ 個削除するが、$x$ が $c$ 個未満で足りないとき、個数はマイナスではなく $0$ 個になる。
$3$ : ( $S$ の最大値) $-$ ( $S$ の最小値) を出力する。

考察

順序付き多重集合というデータ構造を用いるのが簡単です。C++にはstd::multisetという標準ライブラリがありますが、Pythonにはありません。自分で実装しましょう、と言いたいところですが、かなり複雑で実装が大変なデータ構造ですので、誰かが公開しているライブラリを使うことをおすすめします。

順序付き多重集合は、要素が常に昇順に保たれた配列で、同一の要素を複数個含むことを許すデータ構造です。(同一の要素を複数個含むことを許さない場合は、順序付き集合になります)

操作は、要素の追加、要素の削除、$x$ 番目の要素のアクセス、二分探索を $O(\log{N})$ で行えます。

クエリ 1

  • $S$ に $x$ を $1$ 個追加する。

S.add(x)で終わりです。計算量はクエリ $1$ 回につき $O(\log{N})$ です。

クエリ2

  • $S$ から $x$ を $\mathrm{min}(c,Sに含まれるxの個数)$ 個削除する。

$min(c, Sに含まれるxの個数)$ 回、S.remove(x) を行います。

一見すると大量の要素を削除するクエリが何度も来ると計算量が大きくなってしまいそうです。しかしこの問題では、$S$ に追加された要素の数と同じ回数までしか、削除を行うことはできません。存在しないものを消すことはできないということです。

要素の追加クエリ $1$ は $1$ 回につき要素が $1$ 個増えるだけです。したがって、削除の回数は多くても $Q-1$ 回です。

計算量は全体で $O(Q\log{N})$ です。

クエリ3

  • ( $S$ の最大値) $-$ ( $S$ の最小値) を出力する。

S.max - S.min です。計算量はクエリ $1$ 回につき $O(\log{N})$ です。

実装

こちらのライブラリを参考に、筆者が自作したOrderedMultiList(Fenwick木・平方分割・二分探索を利用した順序付き多重集合)のライブラリを貼っておきます。

いくつかの問題が通ることは確認していますが、まだしっかりとテストを行っていないですし、ドキュメントもないので、あまり信用しないでください。

コード

from bisect import insort, bisect_left, bisect_right

__file = open(0)  # 入力の受け取りが重いので、高速なreadlineを使います
readline = __file.readline  # 400ms程度速くなります

class OrderedMultiList:
    class FenwickTree:
        def __init__(self, array):
            n = len(array)
            self.__container = [0] + array[:]
            self.__size = len(self.__container)
            self.__depth = n.bit_length()
            for i in range(n):
                j = i | (i + 1)
                if j < n:
                    self.__container[j + 1] += self.__container[i + 1]

        def add(self, i, x):
            i += 1
            while i < self.__size:
                self.__container[i] += x
                i += i & (-i)

        def sum(self, r):
            s = 0
            while r > 0:
                s += self.__container[r]
                r -= r & (-r)
            return s

        def upper_bound(self, s):
            w, r = 0, 0
            for i in reversed(range(self.__depth)):
                k = r + (1 << i)
                if k < self.__size and w + self.__container[k] <= s:
                    w += self.__container[k]
                    r += 1 << i
            return r

    __load = 1000

    def __init__(self):
        self.__lists = []
        self.__maxes = []
        self.__sizes_list = []
        self.__sizes_ft = self.FenwickTree([])
        self.__len = 0
        self.__block_count = 0

    def add(self, x):
        if self.__len > 0:
            li = bisect_left(self.__maxes, x)
            if li == len(self.__maxes):
                li -= 1
                self.__maxes[-1] = x
            _list = self.__lists[li]
            insort(_list, x)
            self.__sizes_list[li] += 1
            self.__sizes_ft.add(li, 1)
            if len(_list) >= 2 * self.__load:
                self.__split(li)
        else:
            self.__maxes.append(x)
            self.__lists.append([x])
            self.__sizes_list.append(1)
            self.__build_ft()
            self.__block_count += 1
        self.__len += 1

    def discard(self, x):
        if self.__len == 0:
            return False
        li = bisect_left(self.__maxes, x)
        if li == self.__block_count:
            return False
        __list = self.__lists[li]
        lli = bisect_left(__list, x)
        y = __list[lli]
        if x != y:
            return False
        del __list[lli]
        if not __list:
            del self.__lists[li]
            del self.__maxes[li]
            del self.__sizes_list[li]
            self.__block_count -= 1
            self.__len -= 1
            self.__build_ft()
            return True
        if x == self.__maxes[li]:
            self.__maxes[li] = __list[-1]
        self.__len -= 1
        self.__sizes_ft.add(li, -1)
        self.__sizes_list[li] -= 1
        if self.__block_count >= 2 and 2 * self.__sizes_list[li] <= self.__load:
            self.__marge(li)
        return True

    def contains(self, x):
        if self.__len == 0:
            return False
        li = bisect_left(self.__maxes, x)
        if li == self.__block_count:
            return False
        _list = self.__lists[li]
        lli = bisect_left(_list, x)
        y = _list[lli]
        return x == y

    def index(self, x):
        if self.__len == 0:
            return None
        li = bisect_left(self.__maxes, x)
        if li == self.__block_count:
            return None
        _list = self.__lists[li]
        lli = bisect_left(_list, x)
        y = _list[lli]
        if x != y:
            return None
        ft = self.__sizes_ft
        return ft.sum(li) + lli

    def index_right(self, x):
        if self.__len == 0:
            return None
        li = bisect_right(self.__maxes, x)
        if li == self.__block_count:
            y = self.__lists[-1][-1]
            return self.__pos(li - 1, self.__sizes_list[-1] - 1) if y == x else None
        _list = self.__lists[li]
        lli = bisect_right(_list, x) - 1
        if lli == -1:
            if li == 0:
                return None
            y = self.__lists[li - 1][-1]
            return self.__pos(li - 1, self.__sizes_list[li - 1] - 1) if y == x else None
        y = _list[lli]
        if x != y:
            return None
        return self.__pos(li, lli)

    def bisect_left(self, x):
        li = bisect_left(self.__maxes, x)
        if li == self.__block_count:
            return self.__pos(self.__block_count, 0)
        lli = bisect_left(self.__lists[li], x)
        return self.__pos(li, lli)

    def bisect_right(self, x):
        li = bisect_right(self.__maxes, x)
        if li == self.__block_count:
            return self.__pos(self.__block_count, 0)
        lli = bisect_right(self.__lists[li], x)
        return self.__pos(li, lli)

    def count(self, x):
        return self.bisect_right(x) - self.bisect_left(x)

    def le(self, x):
        if self.__len == 0:
            return None
        li = bisect_left(self.__maxes, x)
        if li == self.__block_count:
            return self.__lists[li - 1][-1]
        _list = self.__lists[li]
        lli = bisect_right(_list, x)
        if lli == 0:
            if li == 0:
                return None
            return self.__lists[li - 1][-1]
        return self.__lists[li][lli - 1]

    def lt(self, x):
        if self.__len == 0:
            return None
        li = bisect_left(self.__maxes, x)
        if li == self.__block_count:
            return self.__lists[li - 1][-1]
        _list = self.__lists[li]
        lli = bisect_left(_list, x)
        if lli == 0:
            if li == 0:
                return None
            return self.__lists[li - 1][-1]
        return self.__lists[li][lli - 1]

    def ge(self, x):
        if self.__len == 0:
            return None
        li = bisect_left(self.__maxes, x)
        if li == self.__block_count:
            y = self.__lists[-1][-1]
            return y if y == x else None
        _list = self.__lists[li]
        lli = bisect_left(_list, x)
        return self.__lists[li][lli]

    def gt(self, x):
        if self.__len == 0:
            return None
        li = bisect_right(self.__maxes, x)
        if li == self.__block_count:
            return None
        _list = self.__lists[li]
        lli = bisect_right(_list, x)
        return self.__lists[li][lli]

    @property
    def max(self):
        return self[self.__len - 1]

    @property
    def min(self):
        return self[0]

    def __build_ft(self):
        self.__sizes_ft = self.FenwickTree(self.__sizes_list)

    def __pos(self, li, lli):
        return self.__sizes_ft.sum(li) + lli

    def __split(self, li):
        _list = self.__lists[li]
        sz = self.__sizes_list[li]
        self.__maxes.insert(li + 1, _list[-1])
        self.__lists.insert(li + 1, _list[self.__load:])

        del _list[self.__load:]
        self.__maxes[li] = _list[-1]

        self.__sizes_list[li] = self.__load
        self.__sizes_list.insert(li + 1, sz - self.__load)
        self.__build_ft()
        self.__block_count += 1

    def __marge(self, li):
        if li == 0:
            self.__lists[0].extend(self.__lists[1])
            self.__sizes_list[0] += self.__sizes_list[1]
            self.__maxes[0] = self.__maxes[1]
            del self.__lists[1]
            del self.__maxes[1]
            del self.__sizes_list[1]
            self.__block_count -= 1
            if self.__sizes_list[0] >= 2 * self.__load:
                return self.__split(0)
        else:
            self.__lists[li - 1].extend(self.__lists[li])
            self.__sizes_list[li - 1] += self.__sizes_list[li]
            self.__maxes[li - 1] = self.__maxes[li]
            del self.__lists[li]
            del self.__maxes[li]
            del self.__sizes_list[li]
            self.__block_count -= 1
            if self.__sizes_list[li - 1] >= 2 * self.__load:
                return self.__split(li - 1)
        self.__build_ft()

    def __len__(self):
        return self.__len

    def __contains__(self, x):
        return self.contains(x)

    def __getitem__(self, i):
        if i < 0:
            i += self.__len
        ft = self.__sizes_ft
        li = ft.upper_bound(i)
        lli = i - ft.sum(li)
        return self.__lists[li][lli]

    def get_all_values(self):
        ret = []
        for _list in self.__lists:
            ret.extend(_list)
        return ret


def main():
    Q = int(input())
    S = OrderedMultiList()
    for _ in range(Q):
        query = list(map(int, input().split()))
        q = query[0]
        if q == 1:
            x = query[1]
            S.add(x)
        elif q == 2:
            x, c = query[1:]
            for _ in range(c):
                f = S.discard(x)  # discard削除に成功したか返す
                if not f: break  # もうxがSに残っていないならbreak
        else:  # q == 3
            print(S.max - S.min)


if __name__ == '__main__':
    main()

D問題『FizzBuzz Sum Hard』

問題ページD - FizzBuzz Sum Hard
コーダー正解率:22.0 %
コーダー正解率:66.2 %
コーダー正解率:89.0 %

入力

$N,A,B$ : 整数

考察

$1$ 以上 $N$ 以下の整数の総和 $S_U$は、等差数列の和の公式より $S_U=\dfrac{N(N+1)}{2}$ です。

$A$ の倍数か $B$ の倍数であるものの総和を求められれば、$S_U$ から引くことで答えを求められます。

Aの倍数の総和を求める

$1$ 以上 $N$ 以下の整数に、$A$ の倍数は $\lfloor\dfrac{N}{A}\rfloor$ 個含まれます。($N$ を $A$ で割って小数点以下切り捨て)$A$ の倍数たちは、初項 $A$、公差 $A$ 、項数 $\lfloor\dfrac{N}{A}\rfloor$ の等差数列とみなせるので、$A$ の倍数の総和 $S_A$ は等差数列の和の公式を使って求められます。

$B$ の倍数の総和 $S_B$ も同じ手順で求められます。

そのまま引くとダブルカウント

$S_U - S_A - S_B$ が答えになりそうですが、これは不正解です。

$A$ の倍数でも $B$ の倍数でもある整数を二重に引いてしまっているからです。$A, B$ 両方の倍数である整数のことを、$A$ と $B$ の公倍数といいます。そして、公倍数の中で最も小さいものを**最小公倍数(LCM)**といいます。

$A$ と $B$ の最小公倍数を $L$ とします。$1$ 以上 $N$ 以下の $L$ の倍数の総和を、$A,B$ と同様に求めて、足してあげればダブルカウントが解消されます。

最小公倍数は $L=\dfrac{A\times{B}}{\gcd(A,B)}$ で求められます。ただし、$\gcd(A,B)$ は $A$ と $B$ の最大公約数です。最大公約数は、mathモジュールのgcd関数で求められます。

答えは、$S_U-S_A-S_B+S_L$ です。

コード

from math import gcd


def aseq_sum(a1, d, n):
    return n * (2 * a1 + (n - 1) * d) // 2


def calc(x):
    return aseq_sum(x, x, N // x)


N, A, B = map(int, input().split())
lcm = (A * B) // gcd(A, B)
ans = aseq_sum(1, 1, N) - calc(A) - calc(B) + calc(lcm)
print(ans)

E問題『Distance Sequence』

問題ページE - Distance Sequence
コーダー正解率:2.5 %
コーダー正解率:12.5 %
コーダー正解率:48.6 %

考察

動的計画法(DP)で解きます。

愚直DP

まず愚直なDPを考えます。

$\mathrm{dp}[i][j]$ : 数列の長さが $i+1$ で、末尾が $j$ である、条件を満たす数列の数

と状態を定義し、$\mathrm{dp}[0][j]=1$($1\le{j}\le{M}$)で初期化します。

遷移は、$1$ ~ $M$ まですべて確認し、問題文の条件を満たすなら足すようにします。

これを実装すると、以下のコードになります。

MOD = 998244353

N, M, K = map(int, input().split())
dp = [[0] * (M + 1) for _ in range(N)]

for i in range(1, M + 1):
    dp[0][i] = 1

for i in range(N - 1):
    for j in range(1, M + 1):
        for k in range(1, M + 1):
            if abs(k - j) >= K:
                dp[i + 1][j] += dp[i][k]
                dp[i + 1][j] %= MOD
print(sum(dp[-1]) % MOD)

このコードは正しい答えを返しますが、計算量が $O(NM^2)$ となり間に合いません。

累積和で高速化

条件文を見ると、ある$A_{i+1}$ に対して、条件を満たさない $A_i$ の値が連続していることがわかります。

そこで、累積和を利用して条件を満たさない区間和を $O(1)$ で計算し、全体の総和から引いて求めることにします。

実装

ある長さ $i$ のdp配列をすべて求めたら、list(accumulate(dp))で累積和配列に変換します。累積和配列の一番後ろの値が、全体の総和です。

条件を満たさない区間は $|A_i-A_{i+1}|\lt{K}$、すなわち $A_{i+1}-(K-1)\le{A_i}\le{A_{i+1}+(K-1)}$ です。ここの区間和を累積和で求め、全体の総和から引けばDP配列を埋めていけます。

ただし、$K=0$ のとき、区間の左端と右端が逆転して壊れてしまうので、特別に場合分けするといいです。$K=0$ ならば使える整数に制限はありませんから、答えは $M^N$ です。

コード

from itertools import accumulate

MOD = 998244353


def main():
    def solve():
        N, M, K = map(int, input().split())
        if K == 0:
            return pow(M, N, MOD)

        prev = [1] * (M + 1)
        prev[0] = 0
        prev = list(accumulate(prev))
        for _ in range(N - 1):
            S = prev[-1] % MOD
            dp = [0] * (M + 1)
            for j in range(1, M + 1):
                l = max(1, j - (K - 1))
                r = min(M, j + (K - 1))
                t = prev[r] - prev[l - 1]
                dp[j] = (S - t) % MOD
            prev = list(accumulate(dp))
        return prev[-1] % MOD

    print(solve())


if __name__ == '__main__':
    main()

セグメントツリーで区間和を求める方法

累積和の代わりに、セグメントツリーで区間和を求めてもいいです。やってることは変わりません。

from operator import add


class SegmentTree:
    def __init__(self, op, e, n=0, *, array=None):
        assert (n > 0 and array is None) or (n == 0 and array)
        self.e = e
        self.op = op
        self.n = n if n else len(array)
        self.log = (self.n - 1).bit_length()
        self.size = 1 << self.log
        self.d = [e] * (2 * self.size)

        if array:
            for i in range(self.n):
                self.d[self.size + i] = array[i]
            for i in reversed(range(1, self.size)):
                self.__update(i)

    def set(self, p, x):
        """a[p] に x を代入する"""
        p += self.size
        self.d[p] = x
        for i in range(1, self.log + 1):
            self.__update(p >> i)

    def get(self, p):
        """a[p]を返す"""
        return self.d[p + self.size]

    def prod(self, l, r):
        """[l, r) の総積を返す"""
        if l == 0 and r >= self.n:
            return self.all_prod()

        op = self.op

        sml = self.e
        smr = self.e

        l += self.size
        r += self.size

        while l < r:
            if l & 1:
                sml = op(sml, self.d[l])
                l += 1
            if r & 1:
                r -= 1
                smr = op(self.d[r], smr)
            l >>= 1
            r >>= 1
        return op(sml, smr)

    def all_prod(self):
        """[0, n) の総積を返す"""
        return self.d[1]

    def max_right(self, l, f):
        """f(op(a[l:r])) == True となる最大のr(fが単調ならば)"""
        if l == self.n:
            return self.n
        op = self.op
        size = self.size
        l += size
        sm = self.e

        while True:
            while not (l & 1):
                l >>= 1
            if not f(op(sm, self.d[l])):
                while l < size:
                    l <<= 1
                    if f(op(sm, self.d[l])):
                        sm = op(sm, self.d[l])
                        l += 1
                return l - size
            sm = op(sm, self.d[l])
            l += 1
            if (l & -l) == l:
                break
        return self.n

    def min_left(self, r, f):
        """f(op(a[l:r])) == True となる最小のl(fが単調ならば)"""
        if r == 0:
            return 0
        op = self.op
        size = self.size
        r += self.size
        sm = self.e

        while True:
            r -= 1
            while r and r & 1:
                r >>= 1
            if not f(op(self.d[r], sm)):
                while r < size:
                    r = 2 * r + 1
                    if f(op(self.d[r], sm)):
                        sm = op(self.d[r], sm)
                        r -= 1
                return r + 1 - size
            sm = op(self.d[r], sm)
            if (r & -r) == r:
                break
        return 0

    def __update(self, k):
        self.d[k] = self.op(self.d[k << 1], self.d[k << 1 | 1])

    def __getitem__(self, key):
        if isinstance(key, slice):
            start, stop = key.start, key.stop
            if start is None: start = 0
            if stop is None: stop = self.n
            return self.prod(start, stop)
        else:
            return self.d[key + self.size]

    def __setitem__(self, key, value):
        self.set(key, value)


MOD = 998_244_353


def main():
    N, M, K = map(int, input().split())
    I = [1] * (M + 1)
    I[0] = 0
    prev = SegmentTree(add, 0, array=I)
    for _ in range(N - 1):
        dp = [0] * (M + 1)
        S = prev[:] % MOD
        for i in range(1, M + 1):
            l = max(1, i - K + 1)
            r = min(M, i + K - 1)
            dp[i] = (S - prev[l:r + 1]) % MOD
        prev = SegmentTree(add, 0, array=dp)
    print(prev[:] % MOD)


if __name__ == '__main__':
    main()
20
8
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
20
8