1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

itertoolsを使わずに組合せなどをしたい

Last updated at Posted at 2025-04-15

はじめに

pythonにはitertoolsがあるため、組合せなどが非常に簡単に書けるが、numbaでは多分使えない。
ので自分で書いてみた。
何かを参考にしたりもしていないので、より良い方法はpython以外ならいくらでも転がっているのかもしれない。
以下ではprint(use)の部分に何らかの計算を差し込むことを想定。
何億通りも試すような場合を想定してyieldっぽく、
全ての配列を保存するようなことをせずに実装
また、実際に組合せを出力するのではなく、indexのarrを出力
現在のnumbaではyieldが使えないし、再帰処理を使うと書きにくそうなので、それらなしで。

yield,再帰処理で書きたい場合は以下が詳細に書いてある。

while Trueを使って

良い点

  • シンプル
  • 繰り返す回数をあらかじめ計算しない

forを使った方ではfor i in range(N):でのNを先に計算しないといけないと思う。
これだとNがint64の上限に当たると計算できなくなる。
そんなことなかなか起きないけどnumba使いたいような状況だし問題になることも?

悪い点

  • マルチスレッドと相性は悪そう。
  • while Trueは怖い、ミスってずっと回ってそう
    whileなので、直接マルチスレッドできない。ので下のを見るとupdateをスレッド数だけやっていてちょっと不細工

組み合わせ

いわゆる$ _{n}C _{r}$通りのやつ

def combination(n, r):
    use = np.arange(r)
    while True:
        print(use)
        done_plus = update_combination(use, n)
        if not done_plus:
            break

def update_combination(use, n):
    r = use.shape[0]
    for i in range(r - 1, -1, -1):
        if use[i] != n - (r - i):
            use[i] += 1
            for j in range(1, r - i):
                use[i + j] = use[i] + j
            return True
    return False
n, r = 5, 3
combination(n, r)
>>> [0 1 2], [0 1 3] ... [2 3 4]

上のより(可読性を犠牲にして)ちょっとでもforを削減しようとすると
ここまでする必要があるかはわかりませんが

def update_combination(use, n):
    r = use.shape[0]
    if use[r - 1] != n - 1:
        use[r - 1] += 1
        return True
    else:
        for i in range(r - 2, -1, -1):
            if use[i] != n - (r - i):
                use[i] += 1
                if i != r - 2:
                    for j in range(1, r - i):
                        use[i + j] = use[i] + j
                else:
                    use[i + 1] = use[i] + 1
                return True
        return False

マルチスレッドしようとすると不細工だが、

def multi_combination(n, r, num_threads):
    for thread_id in prange(num_threads):
        use = np.arange(r)
        for _ in range(thread_id):
            done_plus = update_combination(use, n)
            if not done_plus:
                break
    
        while True:
            print(use)
            for _ in range(num_threads):
                done_plus = update_combination(use, n)
                if not done_plus:
                    break
            if not done_plus:
                break

この下のもおんなじ感じで使えると思う

順列

いわゆる$ _{n}P _{r}$通りのやつ
これだけちょっとChatGPTに力を借りた。
また、上のupdate_combinationを使っていて
全ての組み合わせを作ってそこから全ての並び替え($ _{n}P _{n}$通り)を行う形
ので、まずupdate_permutation_n_is_rでupdateしようとして、できなければupdate_combinationを使う。

def permutation (n, r):
    use = np.arange(r)
    while True:
        print(use)
        done_plus = update_permutation(use, n)
        if not done_plus:
            break

def update_permutation(use, n):
    done_update = update_permutation_n_is_r(use)
    if not done_update:
        done_plus = update_combination(use, n)
        if not done_plus:
            return False
    return True

def update_permutation_n_is_r(use):
    n_r = use.shape[0]
    i = n_r - 2
    while i >= 0 and use[i] >= use[i + 1]:
        i -= 1
    if i < 0:
        use[:] = use[::-1]
        return False
    j = n_r - 1
    while use[j] <= use[i]:
        j -= 1
    use[i], use[j] = use[j], use[i]
    use[i + 1:] = use[i + 1:][::-1]
    return True
n, r = 5, 3
permutation (n, r)
>>> [0 1 2], [0 2 1] ... [4 3 2]

重複を許す組合せ

いわゆる$ _{n+r-1}C _{r} = _{n}H _{r}$通りのやつ

def repeated_combination(n, r):
    use = np.zeros(r, dtype="int64")
    while True:
        print(use)
        done_plus = update_repeated_combination(use, n, r)
        if not done_plus:
            break

def update_repeated_combination(use, n):
    r = use.shape[0]
    for i in range(r - 1, -1, -1):
        if use[i] != n - 1:
            use[i] += 1
            for j in range(1, r - i):
                use[i + j] = use[i]
            return True
    return False
n, r = 5, 3
repeated_combination(n, r)
>>> [0 0 0], [0 0 1] ... [4 4 4]

全パターン

例えばpatterns = np.array([4,2,3])なら

  • 要素0は0~3の4パターン
  • 要素1は0~1の2パターン
  • 要素2は0~2の3パターン
    の全ての組み合わせなので、全部で$4\times 2\times 3=24$パターン
def all_pattern(patterns):
    r = len(patterns)
    use = np.zeros(r, dtype="int64")
    while True:
        print(use)
        done_plus =  update_all_pattern(use, patterns)
        if not done_plus:
            break

def update_all_pattern(use, patterns):
    r = use.shape[0]
    for i in range(r - 1, -1, -1):
        if use[i] != patterns[i] - 1:
            use[i] += 1
            use[i + 1 :] = 0
            return True
    return False
patterns = np.array([4,2,3])
all_pattern(patterns)
>>> [0 0 0], [0 0 1] ... [3 1 2]
1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?