はじめに
pythonにはitertoolsがあるため、組合せなどが非常に簡単に書けるが、numbaでは多分使えない。
ので自分で書いてみた。
何かを参考にしたりもしていないので、より良い方法はpython以外ならいくらでも転がっているのかもしれない。
以下ではprint(use)
の部分に何らかの計算を差し込むことを想定。
何億通りも試すような場合を想定してyieldっぽく、
全ての配列を保存するようなことをせずに実装
また、実際に組合せを出力するのではなく、indexのarrを出力
現在のnumbaではyieldが使えないし、再帰処理を使うと書きにくそうなので、それらなしで。
yield,再帰処理で書きたい場合は以下が詳細に書いてある。
while True
を使って
良い点
- シンプル
- 繰り返す回数をあらかじめ計算しない
for
を使った方ではfor i in range(N):
でのNを先に計算しないといけないと思う。
これだとNがint64の上限に当たると計算できなくなる。
そんなことなかなか起きないけどnumba使いたいような状況だし問題になることも?
悪い点
- マルチスレッドと相性は悪そう。
- while Trueは怖い、ミスってずっと回ってそう
whileなので、直接マルチスレッドできない。ので下のを見るとupdateをスレッド数だけやっていてちょっと不細工
組み合わせ
いわゆる$ _{n}C _{r}$通りのやつ
def combination(n, r):
use = np.arange(r)
while True:
print(use)
done_plus = update_combination(use, n)
if not done_plus:
break
def update_combination(use, n):
r = use.shape[0]
for i in range(r - 1, -1, -1):
if use[i] != n - (r - i):
use[i] += 1
for j in range(1, r - i):
use[i + j] = use[i] + j
return True
return False
n, r = 5, 3
combination(n, r)
>>> [0 1 2], [0 1 3] ... [2 3 4]
上のより(可読性を犠牲にして)ちょっとでもforを削減しようとすると
ここまでする必要があるかはわかりませんが
def update_combination(use, n):
r = use.shape[0]
if use[r - 1] != n - 1:
use[r - 1] += 1
return True
else:
for i in range(r - 2, -1, -1):
if use[i] != n - (r - i):
use[i] += 1
if i != r - 2:
for j in range(1, r - i):
use[i + j] = use[i] + j
else:
use[i + 1] = use[i] + 1
return True
return False
マルチスレッドしようとすると不細工だが、
def multi_combination(n, r, num_threads):
for thread_id in prange(num_threads):
use = np.arange(r)
for _ in range(thread_id):
done_plus = update_combination(use, n)
if not done_plus:
break
while True:
print(use)
for _ in range(num_threads):
done_plus = update_combination(use, n)
if not done_plus:
break
if not done_plus:
break
この下のもおんなじ感じで使えると思う
順列
いわゆる$ _{n}P _{r}$通りのやつ
これだけちょっとChatGPTに力を借りた。
また、上のupdate_combination
を使っていて
全ての組み合わせを作ってそこから全ての並び替え($ _{n}P _{n}$通り)を行う形
ので、まずupdate_permutation_n_is_r
でupdateしようとして、できなければupdate_combination
を使う。
def permutation (n, r):
use = np.arange(r)
while True:
print(use)
done_plus = update_permutation(use, n)
if not done_plus:
break
def update_permutation(use, n):
done_update = update_permutation_n_is_r(use)
if not done_update:
done_plus = update_combination(use, n)
if not done_plus:
return False
return True
def update_permutation_n_is_r(use):
n_r = use.shape[0]
i = n_r - 2
while i >= 0 and use[i] >= use[i + 1]:
i -= 1
if i < 0:
use[:] = use[::-1]
return False
j = n_r - 1
while use[j] <= use[i]:
j -= 1
use[i], use[j] = use[j], use[i]
use[i + 1:] = use[i + 1:][::-1]
return True
n, r = 5, 3
permutation (n, r)
>>> [0 1 2], [0 2 1] ... [4 3 2]
重複を許す組合せ
いわゆる$ _{n+r-1}C _{r} = _{n}H _{r}$通りのやつ
def repeated_combination(n, r):
use = np.zeros(r, dtype="int64")
while True:
print(use)
done_plus = update_repeated_combination(use, n, r)
if not done_plus:
break
def update_repeated_combination(use, n):
r = use.shape[0]
for i in range(r - 1, -1, -1):
if use[i] != n - 1:
use[i] += 1
for j in range(1, r - i):
use[i + j] = use[i]
return True
return False
n, r = 5, 3
repeated_combination(n, r)
>>> [0 0 0], [0 0 1] ... [4 4 4]
全パターン
例えばpatterns = np.array([4,2,3])
なら
- 要素0は0~3の4パターン
- 要素1は0~1の2パターン
- 要素2は0~2の3パターン
の全ての組み合わせなので、全部で$4\times 2\times 3=24$パターン
def all_pattern(patterns):
r = len(patterns)
use = np.zeros(r, dtype="int64")
while True:
print(use)
done_plus = update_all_pattern(use, patterns)
if not done_plus:
break
def update_all_pattern(use, patterns):
r = use.shape[0]
for i in range(r - 1, -1, -1):
if use[i] != patterns[i] - 1:
use[i] += 1
use[i + 1 :] = 0
return True
return False
patterns = np.array([4,2,3])
all_pattern(patterns)
>>> [0 0 0], [0 0 1] ... [3 1 2]