Pythonで並列処理するときのコードの雛形です。タスクを何件処理したかの進捗も表示します。自身の備忘録用に投稿しておきます
import dataclasses
import logging
import random
import sys
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from typing import List
# ログ設定
logger = logging.getLogger()
logger.setLevel(logging.INFO)
log_format = logging.Formatter("%(asctime)s [%(levelname)8s] %(message)s")
# 標準出力へのハンドラ
stdout_handler = logging.StreamHandler(sys.stdout)
stdout_handler.setFormatter(log_format)
logger.addHandler(stdout_handler)
# ログファイルへのハンドラ
file_handler = logging.FileHandler("hoge.log", "a+")
file_handler.setFormatter(log_format)
logger.addHandler(file_handler)
# 進捗標示用カウンター
SUCCESS_COUNTER_LOCK = threading.Lock()
SUCCESS_COUNTER = 0
ERROR_COUNTER_LOCK = threading.Lock()
ERROR_COUNTER = 0
@dataclasses.dataclass
class WorkerInput:
"""ワーカーの入力"""
# TODO : 実際のタスクに即して変更
id: int
@dataclasses.dataclass
class WorkerOutput:
"""ワーカーの出力"""
# TODO : 実際のタスクに即して変更
is_success: bool
class Worker:
@staticmethod
def run(input: WorkerInput) -> WorkerOutput:
try:
# TODO : 実際のタスクに即して変更 ここから
time.sleep(10)
# たまに失敗させる
if random.random() < 0.5:
raise Exception("Random Error")
logger.info(f"success. {input=}")
# TODO: ここまで
with SUCCESS_COUNTER_LOCK:
global SUCCESS_COUNTER
SUCCESS_COUNTER += 1
return WorkerOutput(is_success=True)
except Exception as e:
logger.warning(f"error. {input=} err={e}")
with ERROR_COUNTER_LOCK:
global ERROR_COUNTER
ERROR_COUNTER += 1
return WorkerOutput(is_success=False)
if __name__ == "__main__":
worker_num = 3
# TODO: 実際のタスクに即して変更
input_list = [WorkerInput(id=i) for i in range(10)]
input_num = len(input_list)
with ThreadPoolExecutor(max_workers=worker_num, thread_name_prefix="thread") as executor:
futures = []
for input in input_list:
# ワーカーを起動
futures.append(executor.submit(Worker.run, input))
# 進捗表示ループ
while True:
# 進捗表示
logger.info(f"success={SUCCESS_COUNTER} error={ERROR_COUNTER} total={input_num}")
# すべてのワーカーが終了したら進捗標示ループを抜ける
if SUCCESS_COUNTER + ERROR_COUNTER == input_num:
break
time.sleep(10)
# ワーカーの出力を受け取る
output_list: List[WorkerOutput] = [future.result() for future in futures]
logger.info(f"{output_list=}")
出力
2024-09-05 16:09:16,001 [ INFO] success=0 error=0 total=10
2024-09-05 16:09:26,004 [ INFO] success=0 error=0 total=10
2024-09-05 16:09:26,005 [ INFO] success. input=WorkerInput(id=2)
2024-09-05 16:09:26,007 [ INFO] success. input=WorkerInput(id=0)
2024-09-05 16:09:26,007 [ WARNING] error. input=WorkerInput(id=1) err=Random Error
2024-09-05 16:09:36,007 [ INFO] success=2 error=1 total=10
2024-09-05 16:09:36,008 [ INFO] success. input=WorkerInput(id=5)
2024-09-05 16:09:36,008 [ WARNING] error. input=WorkerInput(id=3) err=Random Error
2024-09-05 16:09:36,008 [ INFO] success. input=WorkerInput(id=4)
2024-09-05 16:09:46,011 [ INFO] success. input=WorkerInput(id=6)
2024-09-05 16:09:46,011 [ WARNING] error. input=WorkerInput(id=7) err=Random Error
2024-09-05 16:09:46,012 [ INFO] success=4 error=2 total=10
2024-09-05 16:09:46,012 [ WARNING] error. input=WorkerInput(id=8) err=Random Error
2024-09-05 16:09:56,015 [ INFO] success. input=WorkerInput(id=9)
2024-09-05 16:09:56,019 [ INFO] success=6 error=4 total=10
2024-09-05 16:09:56,020 [ INFO] output_list=[WorkerOutput(is_success=True), WorkerOutput(is_success=False), WorkerOutput(is_success=True), WorkerOutput(is_success=False), WorkerOutput(is_success=True), WorkerOutput(is_success=True), WorkerOutput(is_success=True), WorkerOutput(is_success=False), WorkerOutput(is_success=False), WorkerOutput(is_success=True)]
コードの説明
# ログ設定
logger = logging.getLogger()
logger.setLevel(logging.INFO)
log_format = logging.Formatter("%(asctime)s [%(levelname)8s] %(message)s")
# 標準出力へのハンドラ
stdout_handler = logging.StreamHandler(sys.stdout)
stdout_handler.setFormatter(log_format)
logger.addHandler(stdout_handler)
# ログファイルへのハンドラ
file_handler = logging.FileHandler("hoge.log", "a+")
file_handler.setFormatter(log_format)
logger.addHandler(file_handler)
ログの設定です。並列処理とは関係ないので割愛します
# 進捗標示用カウンター
SUCCESS_COUNTER_LOCK = threading.Lock()
SUCCESS_COUNTER = 0
ERROR_COUNTER_LOCK = threading.Lock()
ERROR_COUNTER = 0
並列処理する入力の全体の数に対して、成功した数(SUCCESS_COUNTER)と失敗した数(ERROR_COUNTER)を記録します。
これらの変数は、並行して動作しているワーカースレッドから同時に書き込まれるため排他制御する必要があります。そのために、ロックを用意しています。このロックを取得したワーカースレッドのみ、カウンターの値を書き換えることができます。
@dataclasses.dataclass
class WorkerInput:
"""ワーカーの入力"""
# TODO : 実際のタスクに即して変更
id: int
@dataclasses.dataclass
class WorkerOutput:
"""ワーカーの出力"""
# TODO : 実際のタスクに即して変更
is_success: bool
並列処理するワーカーの入力と出力を、データクラスで定義します。実際にはここに処理したい入力や出力を定義知るのですが、このコードは雛形なので仮のフィールドを入れています。
class Worker:
@staticmethod
def run(input: WorkerInput) -> WorkerOutput:
try:
# TODO : 実際のタスクに即して変更 ここから
time.sleep(10)
# たまに失敗させる
if random.random() < 0.5:
raise Exception("Random Error")
logger.info(f"success. {input=}")
# TODO: ここまで
with SUCCESS_COUNTER_LOCK:
global SUCCESS_COUNTER
SUCCESS_COUNTER += 1
return WorkerOutput(is_success=True)
except Exception as e:
logger.warning(f"error. {input=} err={e}")
with ERROR_COUNTER_LOCK:
global ERROR_COUNTER
ERROR_COUNTER += 1
return WorkerOutput(is_success=False)
ワーカクラスです。このクラスの関数run()が並列で実行されます。正常終了したときは最後にSUCCESS_COUNTERを+1していますが、例外が発生したときにはERROR_COUNTを+1します。
if __name__ == "__main__":
worker_num = 3
# TODO: 実際のタスクに即して変更
input_list = [WorkerInput(id=i) for i in range(10)]
input_num = len(input_list)
# ワーカーを起動
with ThreadPoolExecutor(max_workers=worker_num, thread_name_prefix="thread") as executor:
futures = []
for input in input_list:
# ワーカーを起動
futures.append(executor.submit(Worker.run, input))
メイン関数の前半です。 executor.submit(Worker.run, input)
がポイントで、 Worker.run()
関数に引数 input
を指定して実行します。ここで実行されると、新たにスレッドが生成され、そのスレッドでWorker.run()
が非同期で実行されます。executor.submit()
の戻り値は 並列タスク実行クラス「feature」 であり、これを保持しておくことでワーカーの出力を取り出すことができます。ここでは features
という配列に格納しています。
# 進捗表示ループ
while True:
# 進捗表示
logger.info(f"success={SUCCESS_COUNTER} error={ERROR_COUNTER} total={input_num}")
# すべてのワーカーが終了したら進捗標示ループを抜ける
if SUCCESS_COUNTER + ERROR_COUNTER == input_num:
break
time.sleep(10)
メイン関数の中盤です。ここまで来たときには、すでに裏でワーカーが並列で実行し始めているので、あとはワーカーが処理を終えるのを待つだけです。なので、無限ループを行い、SUCCESS_COUNTERとERROR_COUNTERの合計値が入力の数と一致するまで待ちます。
# ワーカーの結果を取得する
output_list: List[WorkerOutput] = [future.result() for future in futures]
logger.info(f"{output_list=}")
最後に future.result()
を実行し、ワーカーの出力を受け取ります