1
0

More than 1 year has passed since last update.

【Python】AWS Lambdaでマルチプロセス処理を実装する(ProcessPoolExecutorの代替)

Last updated at Posted at 2023-03-09

はじめに

AWS Lambda関数にてvCPU割り当てを増やして並列処理をしようと画策したのですが、沼にはまりかけてしまったのでここにメモ。

サマリ

  1. Pythonはマルチスレッドでは高速化できない
  2. AWS Lambda上でのプロセス間通信には共有メモリが使えない
  3. pipeのバッファがあふれる(64KiBを超える)想定のサンプルが見つからない
  4. ProcessPoolExecutorと同じように使える関数の並列処理呼び出しモジュールを作った

1. Pythonはマルチスレッドでは高速化できない

Pythonではマルチスレッド(並行処理)とマルチプロセス(並列処理)は明確に区別されます。1プロセス中でアクティブなスレッドは常に一つの原則(GIL)があるため、複数のCPUコアを使用して時間当たりの処理量を稼ぎたい用途ではマルチスレッドは役に立たちません(IO待ちのブロッキングを回避したい用途では有効)。よって、マルチプロセス処理(ProcessPoolExecutorなど)を使用する必要があります。

2. AWS Lambda上でのプロセス間通信には共有メモリが使えない

先に紹介した関数の並列実行のためのモジュールProcessPoolExecutorですが、AWS Lambdaでは使用できません。

[ERROR] OSError: [Errno 38] Function not implemented ...

ProcessPoolExecutorは共有メモリを介した親子プロセス間通信を行いますが、普通のLinuxに存在している/dev/shm(共有メモリのデバイスファイル)がLambda関数の処理系には存在しない(2023年3月現在)のが原因です。共有メモリに代わってPipeを用いたプロセス間通信を行うことで対応できますが、ProcessPoolExecutor的な使い方ができる処理を自作する必要が出てきました。

参考:Lambda環境の/dev/配下のファイル一覧
{
  [
      "/dev/random",
      "/dev/stderr",
      "/dev/full",
      "/dev/zero",
      "/dev/urandom",
      "/dev/null",
      "/dev/stdout",
      "/dev/stdin"
  ]
}

3. Pipeのバッファがあふれる(64KiBを超える)想定のサンプルが見つからない

Pipeによる関数並列実行方法のサンプルはAWS公式ブログ含めいくつか見つけられます。ところがいざ使ってみると処理が止まる、そしていつも同じ場所で止まる。親が悪いのか子が悪いのか、マルチプロセス処理はデバッグが面倒なのもあって地味に時間を費やしてしまいました。
結論、今回並列実行させていた処理のreturnのデータサイズがpipeのバッファ(64KiBの模様)を超えていたため、子プロセス側が(親がバッファを刈り取るまで)pipeへの書き込み待ちで止まってしまっていたのが原因でした。しかし、AWS公式ブログで紹介されているサンプルも含めてpipeのバッファサイズ(64kB)があふれることを想定した例が見つけられず、ならば作ってしまえ、となった次第です。

子プロセスからのreturnをpipeから分割して刈り取る処理(子プロセスは1つ)

ソケットプログラミングでは定番ですが、ファイルディスクリプタというかpipeにデータが書き込まれるのをひたすらwaitですね。
子プロセスがexitしていれば(追加でpipeに書き込まれることはないので)、受信したpickledされたデータを元に戻して終了。

import pickle
import multiprocessing

# 並列実行させたい関数
def _my_job(arg1, arg2):
    ret = arg1 * arg2
    return ret
def _my_job_tuple(args):
    arg1, arg2 = args
    ret = arg1 * arg2
    return ret

def parallel_run(fn, *iterables):
    # return回収先
    recv_data = bytes()
    
    # 受信用Pipe、送信用Pipe作成
    pipe_recv, pipe_send = multiprocessing.Pipe(False)
    
    # 実行したい関数と引数(のオブジェクトデータ)をPipeで子プロセスに渡す
    worker = lambda _fn, _args, _conn: _conn.send(_fn(*_args))
    process = multiprocessing.Process(target=worker, args=(fn, iterables, pipe_send))
 
    # 子プロセス実行
    process.start()

    while(True):
        # 10ミリ秒周期でPipeへの書き込み(子プロセスのreturnデータ)をpolling
        # この例では_ret_connsは0~1要素
        _ret_conns = multiprocessing.connection.wait([ pipe_recv])
        for _conn in _ret_conns:
            # returnデータをバイナリで受信
            recv_data += _conn.recv_bytes()
        
        if process.exitcode is None:
            # 子プロセス実行中はexitcodeはNone
            continue
        else:
            # 子プロセス終了。int型の終了コードが取得可能
            break

    # 子プロセスのreturnはpickleされたオブジェクトのバイナリデータ
    # これをunpickle
    return pickle.loads(recv_data)

if __name__ == "__main__":
    ret = parallel_run(_my_job, ["hogehoge"], 10000)
    print(len(ret))
    print(ret[0])

    ret = parallel_run(_my_job_tuple, (["fugafuga"], 10000) )
    print(len(ret))
    print(ret[0])

4. ProcessPoolExecutorと同じ感覚で使える関数の並列処理呼び出しモジュール

いろいろ加飾して実用になるようにしたもの。つい何となくコメントを英語で書いてしまった。。。

parallel_run.py
# -*- coding: utf-8 -*-
import os
import pickle
import time
import multiprocessing
from typing import Callable
from collections.abc import Iterable

POLLING_CYCLE_SEC = 0.01

class _task_ctx():
    def __init__(self, fn:Callable, args:tuple, seq_no:int):
        worker = lambda _fn, _args, _conn: _conn.send(_fn(*_args))

        self._pipe_recv, self._pipe_send = multiprocessing.Pipe(False)
        self._process = multiprocessing.Process(target=worker, args=(fn, args, self._pipe_send))

        self._seq_no = seq_no
        return
    
    @property
    def pipe_recv(self):
        return self._pipe_recv
    
    @property
    def fileno_recv(self):
        return self._pipe_recv.fileno()
    
    @property
    def seq_no(self):
        return self._seq_no
    
    def start(self):
        self._process.start()
    
    def terminate(self):
        self._process.terminate()
        
    def get_exitcode(self):
        return self._process.exitcode
    
    def close(self):
        self._process.join()
        self._pipe_send.close()
        self._pipe_recv.close()
        self._process.close()
        
        self._seq_no = None
        return


def do(fn:Callable, *iterables:Iterable, max_workers:int=None, timeout=None) -> list:    
    
    """ Returns an iterator equivalent to map(fn, iter) using Pipe (means NOT use shared memory)
        For example, AWS lambda function does NOT allowd to access /dev/shm,
        then cannot use Pool (like ProcessPoolExcecutor) for IPC. This function is
        expected to use such a situation.

    Parameters
    ----------
    fn : Callable
        A callable that will take as many arguments as there are passed iterables.

    max_workers : int, optional
        The maximum number of processes that can be used to execute the given calls.
        If None, then max(1, (os.cpu_count() - 1)) process will run.
    
    timeout: The maximum number of seconds to wait. If None, then there
        is no limit on the wait time.
        

    Raises
    ------
    TypeError :
        Invalid fn or iterables.
    TimeoutError : 
        If timeout is set and expired it.

    Returns
    -------
    list
        return values sorted corresponding to args order.
        if fn raised an exception, None will be set instead of the return value.
    """
    # parallel jobs
    max_workers = max(1, os.cpu_count()-1) if max_workers is None else max_workers
    time_limit  = None if timeout is None else (time.time() + timeout)
    
    proc_contexts  = dict()
    recv_data = dict()

    if not isinstance(fn, Callable):
        raise TypeError("fn is not function")

    if not all([ isinstance(_x, Iterable) for _x in iterables]):
        raise TypeError("Non Itrables element in iterables")

    args_list = list(zip(*iterables))

    is_timed_out = False
    seq_no = 0
    while(True):
        # create new child process if child processes are less than max_workers.
        if (not is_timed_out) and (len(proc_contexts) < max_workers) and (0 < len(args_list)):
            # get arguments of an individual runs
            _arg = args_list.pop(0)
            
            # create child process context and register to "proc_contexts"
            try:
                _ctx = _task_ctx(fn, _arg, seq_no)
            except Exception as e:
                print(f"Exception: {str(e)}")
                continue
            #print(_ctx.fileno)
            proc_contexts[_ctx.fileno_recv] = _ctx
            
            # initialize received data are as bytes
            # (fileno will be reused by OS after current process is gone.
            # Therefore NOT to manage the result using fileno)
            recv_data[seq_no] = bytes()
            
            # start child process
            _ctx.start()
            
            seq_no += 1
            continue
        
        # wait until timer expired
        # (NOT return even if some socket goes ready to receive. Then "timeout" must set
        # bigger than zero but smaller as possible.)
        _ret_conns = multiprocessing.connection.wait(
                        [ _ctx.pipe_recv for _ctx in proc_contexts.values() ],
                        timeout=POLLING_CYCLE_SEC )
        for _conn in _ret_conns:
            _ctx:_task_ctx = proc_contexts[_conn.fileno()]
            # receive return value as bytes (the return data is pickled)
            try:
                recv_data[_ctx.seq_no] += _conn.recv_bytes()
            except EOFError:
                # No more receivable data (may be peer socket was closed)
                pass
            except OSError:
                # never reach here!
                pass
        
        # search joinable processes
        for _ctx in list(proc_contexts.values()):
            if _ctx.get_exitcode() is not None:
                # pop out context from proc_contexts
                # join and close resources
                try:
                    proc_contexts.pop(_ctx.fileno_recv)
                    _ctx.close()
                except Exception as e:
                    print(f"Exception: {str(e)}")
                    
                    
        
        # No more processes will create nor wait its done.
        if (0==len(args_list)) and (0==len(proc_contexts.keys())):
            break
        
        # if timed out, will terminate each processes
        if (time_limit is not None) and (not is_timed_out) and (time_limit < time.time()):
            is_timed_out = True
            args_list.clear()
            for _ctx in list(proc_contexts.values()):
                _ctx.terminate()
                
    
    # No returns if timed out and raise exception.
    if is_timed_out:
        raise TimeoutError
    
    # After all child process returned, Sort recv_data with key (means seq_no) ascending order
    # and unpickle the value (pickled bytes).    
    result_sorted = []
    for _, _v in sorted(recv_data.items()):
        try:
            result_sorted.append(pickle.loads(_v))
        except:
            result_sorted.append(None)
    return result_sorted

参考(Pythonのmultiprocessing)

参考(AWS Lambda関数のvCPU数)

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0