4
7

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

atcoder よく使う関数, アルゴリズムのリスト, その他高速化の為のtipsなど

Last updated at Posted at 2020-05-25

atcoderでよく(?)使うアルゴリズムや関数のリスト.
備忘録, コピペ用関数など. 勉強しながら更新します.

#基本的な処理

##辞書関連
#####辞書の要素をリストにして管理する.
良く忘れるので.

N =  int(input())
dic = {}
for i in range(N):
  x, y = map(int, input().split())
  dic.setdefault(x, []).append(y)

for 文を回すときは
dic : キー
dic.values() : 要素
dic.items() : キーと要素

二分探索

pythonの標準ライブラリでは降順の数列に対する二分探索は含まれていない.
しかし例えば

「最小の値が出てきたときにその値を数列に加え, そうでない場合は数列の対応する要素を更新する」
(atcoder ABC134のE問題で数列の先頭から見て解く場合このような処理が使えるかと思います. 最適な解き方ではないかもしれないですが. )

という作業をpythonのinsert関数で実装するとinsertで$O(N)$の計算量になってしまい, TLEにつながる.

その為appendを使用できるように降順の数列に対して二分探索できると便利である. と考えたら以下のサイトに既に実装してくださっている方がいました!

降順リストに対するbisectの実装 list.sort(reverse=True)に対する配列二分法

#高速化
TLEする際に以下の変更を試みると解消できるかもリスト.
##出力で遅い
あんまり出力について気にして無かったが, for分の中で毎回print()すると思っている以上に計算速度に影響している可能性がある.
ex)
ABC089 D問題
上の書き方 : 約 1000ms前後
下の書き方 : 約 500ms前後
と500ms近い差でした.

#遅い出力
for i in range(N):
  #処理
  print("計算結果")
#速い出力
ans = []
for i in range(N):
  #処理
  ans.append("計算結果")
print(*ans, sep = "\n")

##桁数が大き過ぎる問題
pythonは多倍長整数という整数型で数字の桁が大きくなりすぎてオーバーフローしそうになると勝手にメモリを確保してくれる.
これは非常に便利だが, atcoderにおいてはTLEを生む要因になる.

こうなった時は次の3通りの対処を試すとよい.

  1. MODがある場合はループ中でこまめにMODを取る.
  2. 例えば数字の桁が増える際 10倍 + 1の位 のような形で分解して演算する.
  3. 桁数が莫大になる前にループから抜ける.

1 + 2 : ABC164 D問題
3 : ABC169 C問題

numpy vs list

脳みそが整理されたら書く.

#アルゴリズム

##bit全探査(指数関数のオーダーでの全探査)
問題例 atcoder ABC119:C, ABC147:Cなど

$N < 10$程度(底が2なら15くらいまで可能?)なら有効. 2進数ならシフト演算での実装するのが普通そう
参考文献
bit全探査のアルゴリズム

n = int(input())
bit_base = 4#bit_base^nの全探査になる. 
def Base_10_to_n(X, n):#10進数をbit_base進数に変換
    X_dumy = X
    out = ''
    while X_dumy>0:
        out = str(X_dumy%n)+out
        X_dumy = int(X_dumy/n)
    return out
for i in range(bit_base**n):
  s = Base_10_to_n(i, bit_base)
  s = s.zfill(n)
  for i, bit in enumerate(s):
    #問題に応じた処理を書く
print(ans)

dfs(深さ優先探索)

問題によって実装は変わってくるのでそこら辺は
よくやる再帰関数の書き方 〜 n 重 for 文を機械的に 〜
DFS (深さ優先探索) 超入門! 〜 グラフ・アルゴリズムの世界への入口 〜【前編】

などを読めばよい.
ここではpythonでdfsを書く際には再起関数の深さ上限を変更しないと大体の場合REすることになるのでその点の備忘録(私はよく忘れる)
コードはABC070 D問題の際のものの抜粋.

ちなみにdfsでWAやREが生じたときには

  1. 再起関数の繰り返し条件変えてあるか.(python限定. 以下のコードの上2行参照)
  2. 条件分岐でうっかりreturn させるのを忘れている箇所がないか
  3. 一方通行になっているか. (dfs(1)→dfs(2)→dfs(1)のような無限ループが生じる形になっていないか.)
    の3点を確認してからアルゴリズムを疑った方が良い(個人的な見解).
import sys
sys.setrecursionlimit(10**8)
#ここまでの2行を忘れてはいけない

N = int(input())
a = [[] for i in range(N)]
for i in range(N-1):
  A, B, C = map(int, input().split())
  a[A-1].append([B-1, C])
  a[B-1].append([A-1, C])
dist = [-1]*N

def dfs(now):
  for i, dis_e in a[now]:
    if dist[i] == -1:
      dist[i] = dist[now] + dis_e
      dfs(i)

nCr をpで割った余りの算出の高速化

よく使う
問題例 : atcoder ABC151:E, ABC145:D など

公式

$a$は$b$で割り切れるとする.
$$ pが素数の時,\ a / b \equiv a * b^{(p-2)} \ (mod \ p) $$
が成り立つ.

  • $a$を$b$で割って得られる整数を$p$で割った余りを求めるという作業は $b^{(p-2)}$をかけて$p$で割った余りを求めるという作業で代替できる.
#階乗のリストを作成する
frac = [1]
for i in range(N):
    frac.append(frac[i]*(i+1)%p)

#上の公式を用いてcombの計算を行う
def comb(n, k, mod):
    a=frac[n]
    b=frac[k]
    c=frac[n-k]
    return (a * pow(b, mod-2, mod) * pow(c, mod-2, mod)) % mod

union find

要素をいくつかの木構造で分けて管理するような場合に有効. (ABC 157 D問題とか)
参考文献:🐜ホン


class union_find:
  def __init__(self, n):
    self.par = [-1] * n
  
  def find(self, x):#xの親を見つける
    if self.par[x] < 0:
      return x
    else:
      self.par[x] = self.find(self.par[x])
      return self.par[x]

  def unite(self,x,y):#要素xとyを併合させる
    x,y=self.find(x),self.find(y)#xとyの親の検索
    if x!=y:#親が異なる場合併合させる
      if x>y:
        x,y=y,x#小さい方をxとする. これにより要素の値が小さいものを優先して木の根とする. 
      self.par[x]+=self.par[y] #値を無向木の要素数の和にする.
      self.par[y]=x #枝側は根の位置を格納

  def same(self, x, y):#要素xと要素yが同じ無向木に所属しているかを判定する
    return self.find(x) == self.find(y)#同じ値を持つか否か

  def size(self, x):#要素xが所属する無向木の大きさを返す
    return-self.par[self.find(x)] 

##グラフ系
###ダイクストラ法
ある始点から残りの全頂点への距離を求める. ここでの実装は$O(V^2)$だがheapを使うと$O(V \log E)$にすることもできる. ($E$は辺の数)

cost = [] #iからjまでのコストを表す行列
INF = 10**9
V = "頂点数"
d = [INF]*V
used = [False] *V
def shortest_path(start = 1):#隣接行列の場合のダイクストラ法
  d[start] = 0
  while True:
    v = -1
    for u in range(V):
      if (not used[u]) and (v == -1 or d[u] < d[v]):v = u
    if v == -1:break
    used[v] = True
    for u in range(V):
      d[u] = min(d[u], d[v] + cost[u][v])#cost配列の定義によってu, vは逆かも

###ワーシャルフロイド法
グラフ上の任意の2点間の最短距離を全て求める. $N$を頂点数として$O(N^3)$で計算できる.


N, M = map(int, input().split())
inf = 10**9
d = [[inf]*N for _ in range(N)] #d[i][j]はノードiからノードjの最短距離を保持する. 

for i in range(M):
  a, b, c = map(int, input().split())
  d[a-1][b-1] = c
  d[b-1][a-1] = c#ノードiからjまでの距離を代入する. 
for i in range(N):
  d[i][i] = 0#iからiは自明に0なので0で初期化

def warshall_floyd():
  for k in range(N):
    for i in range(N):
      for j in range(N):
        d[i][j] = min(d[i][j], d[i][k] + d[k][j])

##整数
###素因数分解
素因数分解する. コードが冗長な感じがする. どこかもっと短くしたい.

def prime_factorize(n):
    prime = {}
    while n % 2 == 0:
        if prime.get(2):
          prime[2] += 1
        else:
          prime[2] = 1
        n //= 2
    f = 3
    while f * f <= n:
        if n % f == 0:
            if prime.get(f):
              prime[f] += 1
            else:
              prime[f] = 1
            n //= f
        else:
            f += 2
    if n != 1:
        if prime.get(n):
          prime[n] += 1
        else:
          prime[n] = 1
    return prime

行列の累乗

繰り返し2乗法で計算することで行列$A$の$N$乗を$O(\log N)$で計算する。

実装時の注意点

numpy形式で行列を定義して@の演算を使った実装の方が簡潔に実装できると思いたいところだが、numpyはC言語で実装されているのでオーバーフローになるときがある。そのためlistで行列の積を計算する関数を実装する必要がある。

def mat_mul(a, b, mod) :
    I, J, K = len(a), len(b[0]), len(b)
    c = [[0] * J for _ in range(I)]
    for i in range(I) :
        for j in range(J) :
            for k in range(K) :
                c[i][j] += a[i][k] * b[k][j]
            c[i][j] %= mod
    return c

def calc_matrix_pow(matrix, n, mod):
    '''
    行列のN乗を計算して、各要素をmodで割った余りを返す。
    '''
    result = [[0]*len(matrix) for _ in range(len(matrix))]
    for i in range(len(matrix)):
        result[i][i] = 1

    while n:
        if n & 1:
            result = mat_mul(matrix, result, mod)
        n = n >> 1
        matrix = mat_mul(matrix, matrix, mod)
    return result
4
7
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
7

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?