LoginSignup
1
0

More than 1 year has passed since last update.

pythonのmaltiprocessingでlambda関数みたいなものを使う方法

Last updated at Posted at 2022-12-10

やりたいこと

  • pythonのmultiprocessing(Pool)を使って並列化計算をしたい。
  • poolに関数を与えるときに、渡す関数の引数を調整するのにlambda関数を使いたい。

課題感

しかし、multiprocessingではlambda関数を渡せない。。

動かないコード
from multiprocessing import Pool

# 並列化したい関数
def tgt_func(a, b, c):
    return a + b + c + 1

def main():
    a_list = [1, 2, 3, 4, 5, 6, 7, 8]
    
    # 並列化したい処理(逐次実行版)
    # result = [tgt_func(a, b=10, c=10) for a in a_list]
    
    # 並列処理版
    num_parallel = 8
    with Pool(num_parallel) as p:
        imap = p.imap(lambda a: tgt_func(a, b=10, c=10), a_list)  # ← ここで "lambda's can't be pickled" とエラーがでる
        result = list(imap)

    print(result)

解決方法

次のような実行可能(callable)なクラスをつくって渡す。

from multiprocessing import Pool

# 並列化したい関数
def tgt_func(a, b, c):
    return a + b + c + 1

# 疑似lambda関数
class PseudoLambda():
    def __init__(self, b, c):
        self.b = b
        self.c = c
    def __call__(self, a):
        return tgt_func(a, b=self.b, c=self.c)

def main():
    a_list = [1, 2, 3, 4, 5, 6, 7, 8]
    
    # 並列化したい処理(逐次実行版)
    # result = [tgt_func(a, b=10, c=10) for a in a_list]
    
    # 並列処理版
    num_parallel = 8
    with Pool(num_parallel) as p:
        imap = p.imap(PseudoLambda, a_list) # lambdaで書いていたところを作ったclassで置き換えするだけ
        result = list(imap)

    print(result)

もう少しだけ発展的な方法

次のように、基本的な動作を全てクラスに追加して抽象化すれば、コードを再利用しやすい。

from multiprocessing import Pool
from tqdm import tqdm

# 並列化したい関数
def func(a, b, c):
    return a + b + c + 1

# 並列計算機
class ParallelCalculator():
    def __init__(self, func, **fixed_args):
        self.func = func                    # 実行する関数
        self.fixed_args = fixed_args        # 固定値で使用する引数たち

    def __call__(self, arg):
        return self.func(arg, **self.fixed_args)
        
    def exe(self, num_parallel, args_list):
        with Pool(num_parallel) as p:
            imap = p.imap(self, args_list)
            result = list(tqdm(imap))       # tqdmを使ってプログレスバーも表示
        return result

def main():
    a_list = [1, 2, 3, 4, 5, 6, 7, 8]
    
    # 並列化したい処理(逐次実行版)
    # result = [func(a, b=10, c=10) for a in a_list]
    
    # 並列処理版
    pc = ParallelCalculator(tgt_func, b=10, c=10)
    result = pc.exe(num_parallel=8, a_list)

    print(result)

おわりに

  • とりあえず何とかできた。
  • 並列計算で共通で使用する引数(今回のコードの場合、変数bとかc)が大きいオブジェクトの場合、引数をコピーしないようにrayというライブラリを使ったほうがいいかも。
1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0