AWS Lambda で Python の multiprocessing を扱おうとしたら少しハマったのでメモしておく。
この記事内のコードはすべて AWS Lambda の Python 3.9 ランタイムで実行する想定で書かれています。
例題
例えば、以下のような Lambda 関数があるとする。
from time import sleep
def square(n: int) -> int:
"""1秒かけて入力値の2乗を返す関数"""
sleep(1)
return n ** 2
def handler(event, context):
return [square(i) for i in range(5)]
この関数の実行には5秒かかってしまうため、 square
関数を並列処理して実行時間を短縮したい。
asyncio 版
まずはシンプルにイベントループで非同期処理できないか考える。
Python の非同期処理と言えば asyncio
1 が便利。
from asyncio import get_event_loop, gather, sleep
async def square(n: int) -> int:
"""1秒かけて入力値の2乗を返す関数"""
await sleep(1)
return n ** 2
async def async_handler(event, context):
return await gather(*[square(i) for i in range(5)])
def handler(event, context):
return get_event_loop().run_until_complete(async_handler(event, context))
この Lambda 関数は問題なく動く。
今回のケースでは time.sleep
を asyncio.sleep
に置き換えることができたので、時間がかかる処理を非同期処理として実行することができた。
しかし必ずしも時間がかかる処理が非同期処理として実行できるわけではなく、どうしても時間がかかるブロッキングな処理を並列で実行したいときもあるかもしれない。もしくは CPU バウンドな処理をする際に Lambda ランタイムの複数コアをフル活用したいかもしれない。
以降ではそのようなケースを考えていく。
multiprocessing.Pool 版 (※ 動かない)
multiprocessing.Pool
2 で並列処理できるように直してみる。
from time import sleep
from multiprocessing import Pool
def square(n: int) -> int:
"""1秒かけて入力値の2乗を返す関数"""
sleep(1)
return n ** 2
def handler(event, context):
with Pool(5) as p:
return p.map(square, range(5))
しかしこの Lambda 関数は動かない。
Lambda のランタイムでは共有メモリ (/dev/shm
) を使用することができないため、 multiprocessing
のうち Pool
や Queue
などを使うと Function not implemented
というエラーが発生してしまう。3
multiprocessing.dummy.Pool
という threading
のラッパーもあるが、これについても同様にエラーが発生してしまった。
multiprocessing.Process 版
地道に (?) multiprocessing.Process
でプロセスを作って multiprocessing.Pipe
で結果を受け取るようにしてみる。
from time import sleep
from multiprocessing import Pipe, Process
from multiprocessing.connection import Connection
def square(i: int, conn: Connection):
"""1秒かけて入力値の2乗を返す関数"""
sleep(1)
conn.send(i ** 2)
def handler(event, context):
connections = []
processes = []
for i in range(5):
conn_recv, conn_send = Pipe(False)
process = Process(target=square, args=(i, conn_send))
process.start()
connections.append(conn_recv)
processes.append(process)
return [conn.recv() for conn in connections]
この関数は問題なく動く。
しかし handler
の実装がやや煩雑になってしまった気がするし、 square
内で Connection
オブジェクトを扱わなければならないのも気に入らない。
multiprocessing.Process 改良版
並列処理に関わる部分を parallel.py
に切り出した。
from multiprocessing import Pipe, Process
from typing import Callable
from collections.abc import Iterable
def parallel(task: Callable, args: Iterable) -> list:
connections = []
processes = []
worker = lambda task, arg, conn: conn.send(task(arg))
for arg in args:
conn_recv, conn_send = Pipe(False)
process = Process(target=worker, args=(task, arg, conn_send))
process.start()
connections.append(conn_recv)
processes.append(process)
return [conn.recv() for conn in connections]
from time import sleep
from parallel import parallel
def square(n: int) -> int:
"""1秒かけて入力値の2乗を返す関数"""
sleep(1)
return n ** 2
def handler(event, context):
return parallel(square, range(5))
なんだかいい感じになった気がする。