マルチプロセスの学習で苦戦した点と解決したときのコードをまとめました。
まだ学習中のため記載に誤りなどはある場合は指摘をいただけると嬉しいです。
やろうとしたこと
マルチプロセスの学習として以下の課題を設けて実行しようとした。
- DBからデータを取得(顧客情報のダミーデータを3000件用意)
- データから住所を抽出し、どの都道府県の顧客が多いかをカウントして出力
コード
main.py
import multiprocessing
import re
import time
import random
# データ取得用外部モジュール(割愛)
import sql_test
pattern = "[^都道府県]*[都道府県]"
#マルチプロセスで実行するタスク
def process(lock,dict,user):
# 処理の遅延を疑似的に再現するだけのsleep
time.sleep(random.uniform(0.001, 0.01))
#正規表現で都道府県を取得、カウントを増やす。
res = re.match(pattern,str(user["ADDRESS"])).group()
with lock:
if dict.get(res) == None:
dict[res] = 1
else:
dict[res] += 1
# メイン処理
if __name__ == "__main__":
# 外部モジュールから顧客データを取得
fetch = sql_test.execute()
# lock と 共有データの辞書型配列を渡すためにmagagerを実行
with multiprocessing.Manager() as manager:
# データを安全に更新するためのロック
lock = manager.Lock()
# データ共有用の辞書型配列
dict = manager.dict()
# CPUのコア数に基づいて最適なプロセス数を決定
num_processes = multiprocessing.cpu_count()
with multiprocessing.Pool(processes=num_processes) as pool:
# lock , task を渡すために for でリスト化する。
tasks = [(lock,dict,fetch[i] ) for i in range(len(fetch))]
# リスト化したタスクを引数に渡して実行
# タスクがすべて完了するまで待機する。
pool.starmap(process,tasks)
# 処理結果を表示
print(dict)