たくさんのI/Oバウンドワークロードを実行するために concurrent.futures.ThreadPoolExecutor
を使う方が多いかと思いますが、同時に実行できるタスク数を制限するパラメータはありません。システム要件として「最大 N タスクまで同時に実行可能」という風に実装する必要があった場合、セマフォという選択肢をおすすめしたいです。
セマフォは、限られた容量のリソースへのアクセスを制御するためによく使われるもので、Pythonの標準 threading
モジュールに Semaphore
と BoundedSemaphore
の2つの実装があります。
挙動を理解するために、簡単なコード例を見ましょう。
import threading
sem = threading.Semaphore(1)
print(sem._value) # 1
sem.acquire() # Trueを返す
print(sem._value) # 0
sem.acquire(timeout=5) # 5秒経過後にFalseを返す
sem.release()
print(sem._value) # 1
sem.release()
print(sem._value) # 2
上記で行われていることを順番に説明します。
- 初期値を
1
としたセマフォを作る -
acquire()
でロックの獲得に成功してTrue
を返す。同時に_value
から1
を引く。 - 再び
acquire()
でロックを獲得しようとする。しかし、_value
が0
を下回ることができないので、ロック獲得に失敗してFalse
を返す。 -
release()
でロックを一度解放する。_value
に1
を足す。 -
release()
でロックをもう一度解放する。_value
が初期値を超えて2
となる。
こうしてセマフォが「獲得回数と解放回数」を見ていてくれます。acquire()
でロックを獲得すると、戻り値が True
に、獲得できなかった場合、戻り値が False
になります。
では、有限セマフォと呼ばれる BoundedSemaphore
を使った場合、どうなるのでしょうか?
import threading
sem = threading.BoundedSemaphore(1)
print(sem._value) # 1
sem.acquire() # Trueを返す
print(sem._value) # 0
sem.acquire(timeout=5) # 5秒経過後にFalseを返す
sem.release()
print(sem._value) # 1
sem.release()
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/ (省略) /threading.py", line 504, in release
raise ValueError("Semaphore released too many times")
ValueError: Semaphore released too many times
怒られました。
通常のセマフォは、現在の値が 0
を下回ることができない一方、初期値を超えても問題ありません。それに対して BoundedSemaphore
は、上限値も下限値もあります。解放回数が獲得回数を超えないことを保証したいとき、BoundedSemaphore
が味方になってくれます。
さて、threading.BoundedSemaphore
を使って concurrent.futures.ThreadPoolExecutor
の同時実行可能なタスク数を制御してみましょう。
import logging
from concurrent.futures import ThreadPoolExecutor
from threading import BoundedSemaphore
# 同時実行可能なタスク数
MAX_CONCURRENT_TASKS = 50
# タスク数(仮)
TASK_COUNT = 500
# ロックを獲得するのに最大何秒待つか
ACQUIRE_WAIT = 3
# I/Oを行うタスク関数
def io_bound_workload():
pass
if __name__ == "__main__":
sem = BoundedSemaphore(value=MAX_CONCURRENT_TASKS)
count_ok, count_err, count_done = 0, 0, 0
with ThreadPoolExecutor() as pool:
while count_done < TASK_COUNT:
if sem.acquire(timeout=ACQUIRE_WAIT):
logging.info("Lock acquired")
try:
future = pool.submit(io_bound_workload)
future.add_done_callback(lambda _: sem.release())
count_ok += 1
except Exception as e:
sem.release()
count_err += 1
count_done = count_ok + count_err
else:
logging.debug("Waiting for other jobs to finish..")
いいですね。これで、同時実行タスク数が「50」を超えない程度で、500あるタスクを処理していくことができます。どうですか?スクレーピング、APIのレートリミット回避、システム負荷の軽減など、いろいろなところで役に立ちそうですよね?