目的
・エラトステネスの篩を自分で実装してライブラリ化する。
(他の方のコードを丸ごとコピーでもよいが理解はしておきたいため)
・エラトステネスの篩におけるメモリの課題と細かい高速化を頑張る。
前提
・プログラム言語:Python
エラトステネスの篩
N以下の素数列挙アルゴリズム。
手順は以下の通り(ネットのやつ見た方が早い)
手順1
1〜Nの整数のリストを作る。
また、最初は「1」以外は全て素数仮定してそれぞれの整数に『素数ラベル』を付与する。
手順2
2〜√Nまでで以下の操作を繰り返す。
整数を取り出す。
取り出した整数に「素数」ラベルがついている。→ 無視
取り出した整数に「素数じゃない」ラベルがついている。
➡︎k倍した箇所のラベルを「素数じゃない」に塗り替える。k倍がNを超えたら終わり。
手順3
「素数」ラベルがついている奴だけを取り出す。
これがN以下の素数の集合体
問題点
メモリ消費が激しい
N以下の整数を一括で保持するのでNが大きいとメモリエラーが起きてしまう。。。
→√Nごとに分割してメモリ消費量を減らす!
効率化
偶数は素数でないことはわかりきっているのだから1〜Nの奇数だけをリストアップするようにする。
単純計算でNが半分になるから計算量も半分になるのでは...?(←よくわかっていない。)
以上を踏まえた手順
手順1
まずは、1〜√Nで存在する素数を列挙する。
1〜√Nのうち、奇数のみをリストアップして下から倍数の部分を消していく。
手順2
√N回の以下の操作を繰り返す。
[√N,2√N]、[2√N,3√N]...[(√N-1)√N,√N]
のように探索の範囲を√Nごとに分割してエラトステネスの篩を使う。
√Nの区間が終わるたびに、「素数」ラベルの部分を取り出して結果格納リストに結合。(今回はsetを使った)
これでメモリエラーは解決した。
工夫点
①:範囲内の偶数はもちろん弾く
②:素数の倍数計算は奇数のみ
(整数リストアップの段階で2の倍数を弾いているので偶数をかける意味はない)
③:素数の倍数は全ての範囲ごとに共通なので√N以下の素数に対応した倍数のリストを保持しておく。
④:プログラムにおいて、割り算(mod)の計算が一番遅いらしい
今回はmodを一切使わないという制約を自分に課してプログラムを制作した
2の倍数は 『mod 2 != 0』ではなく、『x^2 - 1』で計算するなど
上記方針を元にPythonで書いたもの
コードがバグだらけだったので修正しました
また、テストの項を追記して妥当性を検証しました
def enumeration_primenumber(n:int):
import math as m
result_p_set = {2} #素数リスト格納用set
## 最初の√Nの範囲での素数を求める
root_n = m.floor(n**0.5)
if n == 1:
return {}
if n == 2:
return result_p_set
if n == 3:
return {2,3}
elif root_n < 3:
p_list = [3]
else:
p_list = [x*2 -1 for x in range(2,m.ceil((root_n)/2)+1)]
n_min = p_list[0]
n_max = p_list[-1]
mult_list = [] #各素数ごとの倍数を保持する変数
temp = []
for i in range(len(p_list)):
num = p_list[i]
if num == False:
continue
mult = 2
while True:
num_mult = num * (mult*2 - 1)
if num_mult > n_max:
mult_list.append(mult)
break
k = m.floor((num_mult - n_min)*(1/2))
p_list[k] = False
temp.append(k)
mult += 1
# 一旦√Nまでの素数をソート&False削除
temp = sorted(temp,reverse=True)
tmp = 0
for i in range(len(temp)):
index = temp[i]
if index == tmp:
continue
del p_list[index]
tmp = index
result_p_set = result_p_set.union(set(p_list))
## √Nまでに現れた素数の倍数を消していく⇒√Nごとに分割
for i in range(1,root_n+1):
if i == root_n:
num_list = [x*2 -1 for x in range(m.floor((root_n*i + 1)/2 + 1),m.ceil((n)/2+1))]
if len(num_list) == 0: break
else:
num_list = [x*2 -1 for x in range(m.floor((root_n*i + 1)/2 + 1),m.ceil((root_n*(i+1))/2)+1)]
n_min = num_list[0]
n_max = num_list[-1]
for j in range(len(p_list)):
p = p_list[j]
mult = mult_list[j]
while True:
num_mult = p * (mult*2 - 1)
if num_mult > n_max:
mult_list[j] = mult
break
k = m.floor((num_mult - n_min)*(1/2))
if k>=0 :
num_list[m.floor(k)] = False
mult += 1
result_p_set = result_p_set.union(set(num_list)) #列挙された素数を結果格納setにぶち込む
result_p_set.discard(False)
sorted_result_p_set = sorted(result_p_set)
return sorted_result_p_set
コードテスト
一番大事なテストとなります。
今回は
①:素数の数が正しいか
②:Atcoderの問題を解いてみてACするか
で検証します。
素数の数テスト
ネットで調べると、10^xごとの素数の数は出てきたのでそれで比較検証します。
print(len(enumeration_primenumber(10)))
print(len(enumeration_primenumber(100)))
print(len(enumeration_primenumber(1000)))
print(len(enumeration_primenumber(10000)))
print(len(enumeration_primenumber(100000)))
print(len(enumeration_primenumber(1000000)))
print(len(enumeration_primenumber(10000000)))
ネットで調べた結果と合致していた。
Atcoderの簡単な問題で試す
簡単なAtcoderの素数問題を解いてみます。
A - 与えられた数より小さい素数の個数について
https://atcoder.jp/contests/tenka1-2012-qualC/tasks/tenka1_2012_9
注意点は、与えられた数「未満」というところに気を付けて上の関数を使って提出します。
n = int(input())
def enumeration_primenumber(n:int):
import math as m
result_p_set = {2} #素数リスト格納用set
## 最初の√Nの範囲での素数を求める
root_n = m.floor(n**0.5)
if n == 1:
return {}
if n == 2:
return result_p_set
if n == 3:
return {2,3}
elif root_n < 3:
p_list = [3]
else:
p_list = [x*2 -1 for x in range(2,m.ceil((root_n)/2)+1)]
n_min = p_list[0]
n_max = p_list[-1]
mult_list = [] #各素数ごとの倍数を保持する変数
temp = []
for i in range(len(p_list)):
num = p_list[i]
if num == False:
continue
mult = 2
while True:
num_mult = num * (mult*2 - 1)
if num_mult > n_max:
mult_list.append(mult)
break
k = m.floor((num_mult - n_min)*(1/2))
p_list[k] = False
temp.append(k)
mult += 1
# 一旦√Nまでの素数をソート&False削除
temp = sorted(temp,reverse=True)
tmp = 0
for i in range(len(temp)):
index = temp[i]
if index == tmp:
continue
del p_list[index]
tmp = index
result_p_set = result_p_set.union(set(p_list))
## √Nまでに現れた素数の倍数を消していく⇒√Nごとに分割
for i in range(1,root_n+1):
if i == root_n:
num_list = [x*2 -1 for x in range(m.floor((root_n*i + 1)/2 + 1),m.ceil((n)/2+1))]
if len(num_list) == 0: break
else:
num_list = [x*2 -1 for x in range(m.floor((root_n*i + 1)/2 + 1),m.ceil((root_n*(i+1))/2)+1)]
#print(num_list)
n_min = num_list[0]
n_max = num_list[-1]
for j in range(len(p_list)):
p = p_list[j]
mult = mult_list[j]
while True:
num_mult = p * (mult*2 - 1)
#print(f"p_list:{p_list},num_mult:{num_mult}")
if num_mult > n_max:
mult_list[j] = mult
break
k = m.floor((num_mult - n_min)*(1/2))
#print(f"num_list:{num_list},n_min:{n_min},k:{k}")
if k>=0 :
num_list[m.floor(k)] = False
mult += 1
result_p_set = result_p_set.union(set(num_list)) #列挙された素数を結果格納setにぶち込む
result_p_set.discard(False)
sorted_result_p_set = sorted(result_p_set)
return sorted_result_p_set
ans = enumeration_primenumber(n)
if len(ans) == 0:
print(0)
elif max(ans) == n :
print(len(ans)-1)
else:
print(len(ans))
無事ACとなりました。
出力スピード
Nが10^6ならかなり早く出力される。(1000ms前後)
Nが10^7になると30秒ぐらいかかった。
他の方のコードだと10^7でも10秒ぐらいで出力されたので単純に力不足。。。
Atcoderの問題を見たら言い換えなど含めて、10^5ぐらいになるように設定されているイメージがあるので問題はなさそう。(ただし、整数問題を解けるようにならないとライブラリ作っても意味がない。。。
勉強になったこと
アルゴリズムとか以前にテストをしっかりしないといけない。
境界値テストやエッジケースへの対応の勉強にもなった。