2
0

競技プログラミングで使えるPython計算のTips

Last updated at Posted at 2024-06-02

モチベーション

最近(今更)AtCoderを本格的に始めた者です。
AtCoder Beginner Contestで300~400点程度の問題をしっかり取れるようになることを目標に頑張っています。

解けなかった問題に対して解説を読んで復習をするときに、

「この解説と同じような理論で実装してるはずなのにTLEになってしまう」
「やりたいことはわかるけど書かなければいけないプログラミングの行数が多すぎる」

ってことありませんか?筆者は滅茶苦茶ありました。

そこで、「実はこうすると速いんだよ」っていうTipsをここでまとめておきます。ほぼ筆者の備忘録用ですが皆さんの参考になれば幸いです。

ビット演算

N個のものがそれぞれあるかないか、本当かウソか、真か偽かなどを扱う問題で$2^N$通りを全探索する場合などがあります。このとき、筆者はこんな面倒くさいコードを書いてしまっていました↓

lst = []
for i in range(2**N):
    s = bin(i)[2:] # bin(2) = '0b10' 
    while(len(s) < N):
        s += '0' + s
    lst.append(list(map(int,list(s))))

(中略)

# M個目のケースのうちK個目の真偽値を確かめるとき
print(lst[M][K]) 

ここで役立ってくるのがビット演算です。
詳しい計算方法や仕組みについてはこちらの記事に書かれていましたのでここでは省略しますが、基本的な演算方法だけ下に書いておきます。

a << b # a * (2**b) と同義
a >> b # a // (2**b) と同義
(a >> b) & 1 # aを2進数にしたとき、下からb桁目が1か否か

これを使うとわざわざ文字列操作とかを行わなくても何も考えずに計算できそうです。

以下は試しにとあるAtCoder Beginner Contestの過去問をビット演算を使って解くか、普通に割り算して解くかで時間を比較してみた例です。

abc014_2.py
#https://atcoder.jp/contests/abc014/tasks/abc014_2
from time import time
n,X = list(map(int,input().split()))

A = list(map(int,input().split()))

start = time()
s = 0
for i in range(n):
    if ((X >> i) & 1):
        s += A[i]
        
print(s)   
print(f"ビット演算の場合:{time()-start:.16f}")

start = time()
s = 0
for i in range(n):
    if X%2==1:
        s += A[i]
    X = X//2
        
print(s)   
print(f"普通に割り算した場合:{time()-start:.16f}")

出力結果:

210
ビット演算の場合:0.0000457763671875
210
普通に割り算した場合:0.0000114440917969

この場合はビット演算よりも普通に割り算した方が早い結果になりました。。。
ただ、割り算や掛け算をしすぎるとビット演算の方が速い場合があるみたいなので、臨機応変に使っていくのがいいかもしれません。

関連Qiita

配列について

$i$に対応する数を数えて$C_i$に格納するような配列$C$を扱う場合、初期値の定義に以下の2つの方法があります。

# 1. 単純配列を使う場合
[0] * N

# 2. numpyを使う場合
import numpy as np
np.zeros(N)

大きな配列、特に2次元配列などを扱う場合や行列計算を用いる場合はnumpy一択ですが、一次元配列のみで十分の場合は1.の方法を使った方が早い場合があります。

sample.py
from time import time
import numpy as np

# 配列のサイズ
N = 1000000
M = 10000
# ランダムな整数配列を生成
array = np.random.randint(0, M+1, size=N)

print("arrayがnumpy配列の場合")

start = time()

count = [0] * (M+1)
for i in range(N):
    count[array[i]] += 1
    
print(time()-start) # 0.1425337791442871

start = time()

count = np.zeros(M+1)
for i in range(N):
    count[array[i]] += 1
    
print(time()-start) # 0.24925613403320312

array = array.tolist()

print("arrayが通常配列の場合")

start = time()

count = [0] * (M+1)
for i in range(N):
    count[array[i]] += 1
    
print(time()-start) # 0.07651591300964355

start = time()

count = np.zeros(M+1)
for i in range(N):
    count[array[i]] += 1
    
print(time()-start) # 0.1901106834411621

print("numpy関数を使ったウラワザ")

start = time()

count = np.bincount(array)

print(time()-start) # 0.020976781845092773

(追記 2024/6/2)
numpyを使う場合はできるだけnumpyの関数を使って計算する手法を取ればそっちのほうが圧倒的に速いらしいのですが、そのような関数を覚えていない場合や計算が複雑でどうしてもforループしなければならない場合は配列の方が速いみたいですね。(コメントありがとうございます。)
理由はまだちゃんと調べられてはいないですが、ChatGPTに聞いたところnumpyはC言語のメモリ管理関数を呼び出しているので少し遅くなるらしいです。

functools.cacheについて

functools.cacheを使うと既に計算された出力値が保存されるので、再帰関数の計算が圧倒的に速くなる。

from functools import cache
import numpy as np
from time import time

# 2^M を modで割った値を出す

M = 10 ** 7
mod = 998244353

def pow_a(M):
    s = 1
    for i in range(M):
        s *= 2
        s %= mod
    return s

def pow_recursive(M):
    if M==0:
        return 1
    elif M==1:
        return 2
    else:
        return (pow_recursive(M//2) * pow_recursive(M-M//2)) % mod

@cache
def pow_recursive_cache(M):
    if M==0:
        return 1
    elif M==1:
        return 2
    else:
        return (pow_recursive_cache(M//2) * pow_recursive_cache(M-M//2)) % mod

print("普通に計算する場合")
start = time()
ans = pow_a(M)
print(f"{ans},{time()-start}s") # 1.1814076900482178s

print("再帰的に計算する場合")
start = time()
ans = pow_recursive(M)
print(f"{ans},{time()-start}s") # 2.6318399906158447s

print("functools.cacheを使用する場合")
start = time()
ans = pow_recursive_cache(M)
print(f"{ans},{time()-start}s") # 0.0s
2
0
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
0