やりたいこと
- pythonのmultiprocessing(Pool)を使って並列化計算をしたい。
- poolに関数を与えるときに、渡す関数の引数を調整するのにlambda関数を使いたい。
課題感
しかし、multiprocessingではlambda関数を渡せない。。
動かないコード
from multiprocessing import Pool
# 並列化したい関数
def tgt_func(a, b, c):
return a + b + c + 1
def main():
a_list = [1, 2, 3, 4, 5, 6, 7, 8]
# 並列化したい処理(逐次実行版)
# result = [tgt_func(a, b=10, c=10) for a in a_list]
# 並列処理版
num_parallel = 8
with Pool(num_parallel) as p:
imap = p.imap(lambda a: tgt_func(a, b=10, c=10), a_list) # ← ここで "lambda's can't be pickled" とエラーがでる
result = list(imap)
print(result)
解決方法
次のような実行可能(callable)なクラスをつくって渡す。
from multiprocessing import Pool
# 並列化したい関数
def tgt_func(a, b, c):
return a + b + c + 1
# 疑似lambda関数
class PseudoLambda():
def __init__(self, b, c):
self.b = b
self.c = c
def __call__(self, a):
return tgt_func(a, b=self.b, c=self.c)
def main():
a_list = [1, 2, 3, 4, 5, 6, 7, 8]
# 並列化したい処理(逐次実行版)
# result = [tgt_func(a, b=10, c=10) for a in a_list]
# 並列処理版
num_parallel = 8
with Pool(num_parallel) as p:
imap = p.imap(PseudoLambda, a_list) # lambdaで書いていたところを作ったclassで置き換えするだけ
result = list(imap)
print(result)
もう少しだけ発展的な方法
次のように、基本的な動作を全てクラスに追加して抽象化すれば、コードを再利用しやすい。
from multiprocessing import Pool
from tqdm import tqdm
# 並列化したい関数
def func(a, b, c):
return a + b + c + 1
# 並列計算機
class ParallelCalculator():
def __init__(self, func, **fixed_args):
self.func = func # 実行する関数
self.fixed_args = fixed_args # 固定値で使用する引数たち
def __call__(self, arg):
return self.func(arg, **self.fixed_args)
def exe(self, num_parallel, args_list):
with Pool(num_parallel) as p:
imap = p.imap(self, args_list)
result = list(tqdm(imap)) # tqdmを使ってプログレスバーも表示
return result
def main():
a_list = [1, 2, 3, 4, 5, 6, 7, 8]
# 並列化したい処理(逐次実行版)
# result = [func(a, b=10, c=10) for a in a_list]
# 並列処理版
pc = ParallelCalculator(tgt_func, b=10, c=10)
result = pc.exe(num_parallel=8, a_list)
print(result)
おわりに
- とりあえず何とかできた。
- 並列計算で共通で使用する引数(今回のコードの場合、変数bとかc)が大きいオブジェクトの場合、引数をコピーしないようにrayというライブラリを使ったほうがいいかも。