モチベーション
最近(今更)AtCoderを本格的に始めた者です。
AtCoder Beginner Contestで300~400点程度の問題をしっかり取れるようになることを目標に頑張っています。
解けなかった問題に対して解説を読んで復習をするときに、
「この解説と同じような理論で実装してるはずなのにTLEになってしまう」
「やりたいことはわかるけど書かなければいけないプログラミングの行数が多すぎる」
ってことありませんか?筆者は滅茶苦茶ありました。
そこで、「実はこうすると速いんだよ」っていうTipsをここでまとめておきます。ほぼ筆者の備忘録用ですが皆さんの参考になれば幸いです。
ビット演算
N個のものがそれぞれあるかないか、本当かウソか、真か偽かなどを扱う問題で$2^N$通りを全探索する場合などがあります。このとき、筆者はこんな面倒くさいコードを書いてしまっていました↓
lst = []
for i in range(2**N):
s = bin(i)[2:] # bin(2) = '0b10'
while(len(s) < N):
s += '0' + s
lst.append(list(map(int,list(s))))
(中略)
# M個目のケースのうちK個目の真偽値を確かめるとき
print(lst[M][K])
ここで役立ってくるのがビット演算です。
詳しい計算方法や仕組みについてはこちらの記事に書かれていましたのでここでは省略しますが、基本的な演算方法だけ下に書いておきます。
a << b # a * (2**b) と同義
a >> b # a // (2**b) と同義
(a >> b) & 1 # aを2進数にしたとき、下からb桁目が1か否か
これを使うとわざわざ文字列操作とかを行わなくても何も考えずに計算できそうです。
以下は試しにとあるAtCoder Beginner Contestの過去問をビット演算を使って解くか、普通に割り算して解くかで時間を比較してみた例です。
#https://atcoder.jp/contests/abc014/tasks/abc014_2
from time import time
n,X = list(map(int,input().split()))
A = list(map(int,input().split()))
start = time()
s = 0
for i in range(n):
if ((X >> i) & 1):
s += A[i]
print(s)
print(f"ビット演算の場合:{time()-start:.16f}")
start = time()
s = 0
for i in range(n):
if X%2==1:
s += A[i]
X = X//2
print(s)
print(f"普通に割り算した場合:{time()-start:.16f}")
出力結果:
210
ビット演算の場合:0.0000457763671875
210
普通に割り算した場合:0.0000114440917969
この場合はビット演算よりも普通に割り算した方が早い結果になりました。。。
ただ、割り算や掛け算をしすぎるとビット演算の方が速い場合があるみたいなので、臨機応変に使っていくのがいいかもしれません。
関連Qiita
- Python ビット演算 超入門: ビット演算の詳しい計算方法については説明してくださっています
- こわくないbit全探索1 入門編: bit全探索ってなに?【競プロ解説】 : 例題や説明などが分かりやすく書かれています
- 【AtCoder】ビット全探索〜問題まとめましたっ!〜 : ビット全探索を使う問題を集めてくださってます
配列について
$i$に対応する数を数えて$C_i$に格納するような配列$C$を扱う場合、初期値の定義に以下の2つの方法があります。
# 1. 単純配列を使う場合
[0] * N
# 2. numpyを使う場合
import numpy as np
np.zeros(N)
大きな配列、特に2次元配列などを扱う場合や行列計算を用いる場合はnumpy一択ですが、一次元配列のみで十分の場合は1.の方法を使った方が早い場合があります。
from time import time
import numpy as np
# 配列のサイズ
N = 1000000
M = 10000
# ランダムな整数配列を生成
array = np.random.randint(0, M+1, size=N)
print("arrayがnumpy配列の場合")
start = time()
count = [0] * (M+1)
for i in range(N):
count[array[i]] += 1
print(time()-start) # 0.1425337791442871
start = time()
count = np.zeros(M+1)
for i in range(N):
count[array[i]] += 1
print(time()-start) # 0.24925613403320312
array = array.tolist()
print("arrayが通常配列の場合")
start = time()
count = [0] * (M+1)
for i in range(N):
count[array[i]] += 1
print(time()-start) # 0.07651591300964355
start = time()
count = np.zeros(M+1)
for i in range(N):
count[array[i]] += 1
print(time()-start) # 0.1901106834411621
print("numpy関数を使ったウラワザ")
start = time()
count = np.bincount(array)
print(time()-start) # 0.020976781845092773
(追記 2024/6/2)
numpyを使う場合はできるだけnumpyの関数を使って計算する手法を取ればそっちのほうが圧倒的に速いらしいのですが、そのような関数を覚えていない場合や計算が複雑でどうしてもforループしなければならない場合は配列の方が速いみたいですね。(コメントありがとうございます。)
理由はまだちゃんと調べられてはいないですが、ChatGPTに聞いたところnumpyはC言語のメモリ管理関数を呼び出しているので少し遅くなるらしいです。
functools.cacheについて
functools.cacheを使うと既に計算された出力値が保存されるので、再帰関数の計算が圧倒的に速くなる。
from functools import cache
import numpy as np
from time import time
# 2^M を modで割った値を出す
M = 10 ** 7
mod = 998244353
def pow_a(M):
s = 1
for i in range(M):
s *= 2
s %= mod
return s
def pow_recursive(M):
if M==0:
return 1
elif M==1:
return 2
else:
return (pow_recursive(M//2) * pow_recursive(M-M//2)) % mod
@cache
def pow_recursive_cache(M):
if M==0:
return 1
elif M==1:
return 2
else:
return (pow_recursive_cache(M//2) * pow_recursive_cache(M-M//2)) % mod
print("普通に計算する場合")
start = time()
ans = pow_a(M)
print(f"{ans},{time()-start}s") # 1.1814076900482178s
print("再帰的に計算する場合")
start = time()
ans = pow_recursive(M)
print(f"{ans},{time()-start}s") # 2.6318399906158447s
print("functools.cacheを使用する場合")
start = time()
ans = pow_recursive_cache(M)
print(f"{ans},{time()-start}s") # 0.0s