C - Filling 3x3 array
全探索で解いていきます。
1行目をrow1、2行目をrow2。1行目の1列目のデータをrow1[1]、2列目のデータをrow1[2]と表記します。
h1 = row1[1] + row1[2] + row1[3]
よりrow1に関して、
row1[1]とrow1[2]が決まれば、row1[3] = h1 - (row1[1] + row1[2])
と自動的に決まります。
row2についても同様に1列目・2列目が決まれば自動的に決まります。
row3について、w1 = row1[1] + row2[1] + row3[1]
より
row3[1] = w1 - (row1[1] + row2[1])
と自動的に決まります。2列目・3列目についても同様です。
実装については、forの4重ループは自分が混乱しそうだったので、
合計でxになる正整数3つの組み合わせ
を返す関数をget_value_combination
として定義し、row1、row2のパターンを全列挙しそれぞれforループで回しています。
def main():
A = list(map(int, input().split()))
h = A[:3]
w = A[3:]
cnt = 0
# row1のパターンを全列挙
for row1 in get_value_combination(h[0]):
# row2にパターンを全列挙
for row2 in get_value_combination(h[1]):
row3 = [0] * 3
# row3は自動的に決まる
for i in range(3):
row3[i] = w[i] - row1[i] - row2[i]
# row3が0以下になる場合は次のパターンを試す
if row3[i] <= 0:
break
else:
# 上記のループが完了し、3列目の合計がh3に一致する場合、カウントに+1
if sum(row3) == h[2]:
cnt += 1
print(cnt)
# 合計がxになる正整数3つの組み合わせのパターンを返す関数
def get_value_combination(x: int):
res = []
# 2・3列目にそれぞれ少なくとも1は入るので,1列目の数値の範囲は1~x-2
for i in range(1, x - 1):
# 2列目の数値の範囲は1~x-i-1
for j in range(1, x - i):
# 3列目はx-i-j
res.append((i, j, x - i - j))
return res
if __name__ == "__main__":
main()
D - Union of Interval
まず入力(L,R)をLについてソートします。
ある時点での暫定的な区間=(X,Y)
とすると、ソートしたことでX<=L
となります。
あとは以下の3パターンについて処理していきます。
- 現在の右端YよりLが大きい
Y<L
=区間が離れている
->それ以降重なった区間はないので(X,Y)を区間として確定し、(L,R)を新たな暫定区間(X,Y)とする - 現在の右端YよりLが小さいか同じ
X<=L<=Y
=区間が重なっている- 現在の右端YよりRが大きい
X<=L<=Y<R
=暫定区間からはみ出している
->暫定区間を拡張しRを新たなYとする - 現在の右端YがRより小さいor同じ
X<=L<=R<=Y
=暫定区間に完全に含まれている
->何もしない
- 現在の右端YよりRが大きい
def main():
N = int(input())
zones = []
for _ in range(N):
zones.append(tuple(map(int, input().split())))
# 入力をLについてソート
zones.sort()
# 先頭の値を初期値に設定
x, y = zones[0]
ans = []
for l, r in zones:
# lが現在の右端yより大きい場合=別の区間となる
# 現在の(x,y)を区間として確定し答えに追加。(l,r)を新しい区間として設定
if l > y:
# 確定した区間(x,y)を出力用にstr型で答えに格納
ans.append(f"{x} {y}")
x = l
y = r
continue
# lが現在の区間の中にあり(x<=l<=y)、rが現在の左端yより大きい場合、
# 現在の区間を拡張
if r > y:
y = r
else:
# 最後に残った区間を確定
ans.append(f"{x} {y}")
print("\n".join(ans))
if __name__ == "__main__":
main()
E - Takahashi's Anguish
コンテスト中に思いついた方法はセグメント木を用いる方法でした。
コンテスト中はセグメント木を直接変更していましたが、後から考えるとsegfuncと単位元eを適切に設定すればセグメント木を変更する必要はありませんでした…
処理としては、ある人Aを数列Pに加えた時に生じる不満度の合計total_fuman
と不満度の合計が最も低い人のインデックスをセグメント木で管理しています。
トポロジカルソートの、出次が0のノードを選ぶ部分を他者からの不満度が最も少ない人を選ぶようにしている感じです。
以下をN回繰り返します。
- セグメント木から現時点でPに加えた時の
不満度が最も少ない人A
を選択 - Aをセグメント木から削除し、不満度
total_fuman[A]
を答えに加算 - Aは自分より後ろの人に対して不満を持つことはないので
嫌いな人X[A]
の合計の不満度からAが持つ不満度C[A]
を引く - セグメント木を更新し1に戻る
セグメント木の構築がNlogN, 更新がlogNでそれをN回なので、全体の計算量はNlogN
だと思います。UnionFindを使った模範解答より遅いですが、ACできました。
import sys
INF = float("inf")
def segfunc(x, y):
# x, yは(不満度の合計, index)の形式
# 不満度が小さい方を返す
if x[0] < y[0]:
return x
else:
return y
class SegTree:
def __init__(self, x_list, init, segfunc):
self.length = len(x_list)
self.init = init
self.segfunc = segfunc
self.Height = len(x_list).bit_length() + 1
self.Tree = [init] * (2**self.Height)
self.num = 2 ** (self.Height - 1)
for i in range(len(x_list)):
self.Tree[2 ** (self.Height - 1) + i] = x_list[i]
for i in range(2 ** (self.Height - 1) - 1, 0, -1):
self.Tree[i] = segfunc(self.Tree[2 * i], self.Tree[2 * i + 1])
def __getitem__(self, i):
return self.select(i)
def __len__(self):
return self.length
def __str__(self):
return str(self.Tree[self.num : self.num + self.length])
def select(self, k):
"""
k番目の要素を取得する
Args:
k (int): 取得する要素のインデックス(0-index)
Returns:
k番目の要素
"""
return self.Tree[k + self.num]
def update(self, k, x):
"""
k番目の要素をxに更新する
Args:
k (int): 更新する要素のインデックス(0-index)
x : k番目に入る新たな値
"""
i = k + self.num
self.Tree[i] = x
while i > 1:
if i % 2 == 0:
self.Tree[i // 2] = self.segfunc(self.Tree[i], self.Tree[i + 1])
else:
self.Tree[i // 2] = self.segfunc(self.Tree[i - 1], self.Tree[i])
i //= 2
def query(self, l, r):
"""
半開区間[l:r)についてsegfuncでの演算結果を返す
Args:
l (int): 区間の左端(範囲に含まれる)
r (int): 区間の右端(範囲に含まれない)
Returns:
[l:r)のsegfuncでの計算結果
"""
result = self.init
l += self.num
r += self.num
while l < r:
if l % 2 == 1:
result = self.segfunc(result, self.Tree[l])
l += 1
if r % 2 == 1:
result = self.segfunc(result, self.Tree[r - 1])
l //= 2
r //= 2
return result
input = sys.stdin.readline
def main():
N = int(input())
# Xを0-indexに変換
X = list(map(lambda x: int(x) - 1, input().split()))
C = list(map(int, input().split()))
# ある人が前にいるときの不満度の合計を計算
total_fuman = [0] * N
for i in range(N):
idx = X[i]
total_fuman[idx] += C[i]
# (不満度の合計, idx)の形でリストに格納
fuman_idx = [(fuman, i) for i, fuman in enumerate(total_fuman)]
# 不満度の合計とそのindexをセグメント木に格納
# セグメント木は区間[0, N)の不満度の最小値とそのindexを返す
# 単位元は無限, indexは適当に-1を入れておく
e = (INF, -1)
seg = SegTree(fuman_idx, e, segfunc)
ans = 0
for _ in range(N):
# 現時点で合計の不満度が最小である人とその不満度を取得
min_fuman, min_idx = seg.Tree[1]
# 不満度を答えに加算
ans += min_fuman
# その人をセグメント木から取り出す。処理としてはそのindexの要素を単位元にする
seg.update(min_idx, e)
# 消した人が嫌いな人に対して持っている不満度を減らしてセグメント木を更新
dislike = X[min_idx]
fuman = C[min_idx]
seg.update(dislike, (seg.select(dislike)[0] - fuman, dislike))
print(ans)
if __name__ == "__main__":
main()