Lately, loading Python modules and machine-learning weights has become slow enough to hurt my development workflow, so I poked around for a better mechanism.
My first thought was a client/server setup, so I planned to introduce IPC or shared memory, but partway through I started wondering whether simply forking a preloaded process would be enough, and tried that. It turned out to be a dead end: CUDA apparently does not work across fork().
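For reference, the failure looks roughly like this (a minimal sketch assuming PyTorch and a CUDA-capable GPU; it is not part of the script below):

# Minimal repro sketch (assumes PyTorch and a CUDA GPU): once the parent
# process has touched CUDA, a fork()ed child can no longer initialize it.
import os
import torch
torch.zeros(1).cuda()              # initialize CUDA in the parent
pid = os.fork()
if pid == 0:                       # child
    try:
        torch.zeros(1).cuda()      # typically raises: Cannot re-initialize CUDA in forked subprocess
    except RuntimeError as e:
        print("child:", e)
    os._exit(0)
os.waitpid(pid, 0)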
Instead, I tested splitting the program into a frontend and a backend using fork, a socket pair, and pickle, and it now seems to work reasonably well, so I am publishing it experimentally. I am not fully convinced that the socket-pair approach is completely safe, so I cannot take responsibility if something goes wrong. You have been warned.
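The core of the idea is tiny (a bare-bones, standard-library-only sketch; the actual script below adds dill, per-CPU sockets, and locks on top of it):

# Bare-bones sketch: fork a backend child and exchange pickled Python objects
# over an AF_UNIX socket pair.
import os
import pickle
import socket
backend_sock, frontend_sock = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM)
pid = os.fork()
if pid == 0:                                    # backend: receive work, send back the result
    frontend_sock.close()
    args = pickle.loads(backend_sock.recv(65536))
    backend_sock.sendall(pickle.dumps(sum(args)))
    os._exit(0)
backend_sock.close()                            # frontend: submit work, wait for the result
frontend_sock.sendall(pickle.dumps([1, 2, 3]))
print(pickle.loads(frontend_sock.recv(65536)))  # -> 6
os.waitpid(pid, 0)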
(Added October 23-24, 2023: implemented a run command that carries over the console context, added a decorator for executing functions on the remote backend, return-value support, binary serialization of passed arguments, detailed error reporting for remote execution, callback support, a fix for a nasty bug, and a title change.)
my_preload_python.py
# Test of wiring RCE (remote code execution) into the preloader; one parent, two children
# module preloader for fast debug!
# SPDX-License-Identifier: Apache-2.0
# To kill a runaway process:
# kill -s KILL `ps -ef | grep my_preload_python | grep -v grep | awk '{print $2}'`
import signal
import subprocess
import os
import readline # command history for InteractiveConsole
import code
import setproctitle
import socket
num_cpus = os.cpu_count()
psockets = []
csockets = []
for i in range(num_cpus):
    # pipe may be faster but we use socketpair for security
    psocket, csocket = socket.socketpair(socket.AF_UNIX, socket.SOCK_RAW) # socket.SOCK_STREAM
    psockets.append(psocket)
    csockets.append(csocket)
import select
import atexit
import psutil
histpath = os.path.expanduser("~/.python_history")
def save_history(histpath):
    readline.write_history_file(histpath)
atexit.register(save_history, histpath)
# Decorator for executing a function on the backend
# usage:
# @exec_function_on_backend
# def test(x, y):
# r = x + y
# print(r)
#
# test(1, 3)
"""
import marshal, pickle
def exec_function_on_backend(func):
    def wrapper(*args, **kwargs):
        return send_to_backend((marshal.dumps(func.__code__), func.__name__, args, kwargs))
    return wrapper
"""
import dill as pickle
def exec_function_on_backend(func):
    def wrapper(*args, **kwargs):
        return send_to_backend((func, func.__name__, args, kwargs))
    return wrapper
import threading
per_cpu_mutexes = []
per_cpu_fs = []
for i in range(num_cpus):
    per_cpu_mutexes.append(threading.Lock())
    per_cpu_fs.append(None)
def send_to_backend(data):
    cpu = psutil.Process().cpu_num()
    # print("debug: cpu: %s"%cpu)
    os.sched_setaffinity(0, [cpu])
    with per_cpu_mutexes[cpu]: # acquire per-cpu lock
        fd = csockets[cpu].fileno()
        if not per_cpu_fs[cpu]:
            per_cpu_fs[cpu] = os.fdopen(fd, "rb")
        send_data(csockets[cpu], data)
        # print("send")
        csockets[cpu].setblocking(True)
        obj = None
        try:
            # print("blocking read")
            obj = pickle.load(per_cpu_fs[cpu])
            # print("pickle loaded")
        except Exception as e:
            print(e)
        # print("non-blocking read")
        csockets[cpu].setblocking(False)
        _ = per_cpu_fs[cpu].read() # discard any leftover buffered data
        # print("drop trailing data")
        return obj
def send_data(socket, data):
    socket.sendall(pickle.dumps(data, recurse=True))
    # socket.sendall(pickle.dumps(data))
import types
import traceback
wq_backend_threads = []
def setup_recv_and_workqueues():
    for i in range(num_cpus):
        thread = threading.Thread(target=recv_and_exec_loop, args=(i,))
        wq_backend_threads.append(thread)
    for thread in wq_backend_threads:
        thread.start()
    for thread in wq_backend_threads:
        thread.join()
def recv_and_exec_loop(cpu): # no locks, please lock on client side
    os.sched_setaffinity(0, [cpu])
    f = os.fdopen(psockets[cpu].fileno(), "rb")
    r = None
    while True:
        psockets[cpu].setblocking(True)
        try:
            obj = pickle.load(f)
            # print("backend: data load")
            try:
                # func = types.FunctionType(marshal.loads(obj[0]), globals(), obj[1])
                # r = func(*obj[2], **obj[3])
                r = obj[0](*obj[2], **obj[3])
                # print("backend: exec func")
            except Exception as e:
                traceback.print_exc()
        except Exception as e:
            print(e)
        try:
            psockets[cpu].setblocking(False)
            _ = f.read() # discard any leftover buffered data
            # print("backend: drop trailing data")
        except:
            pass
        send_data(psockets[cpu], r)
        # print("backend: send data")
term_pid = -1
while True:
    if term_pid == -1 or term_pid == child_pid:
        child_pid = os.fork()
        if child_pid == 0: # Front End Python Process
            for psocket in psockets:
                psocket.close()
            setproctitle.setproctitle("my_preload_python.py - Front End")
            console = code.InteractiveConsole(locals=locals())
            if os.path.exists(histpath): # may not exist on the first run
                readline.read_history_file(histpath) # command history
            def my_input(prompt):
                import re
                r = input(prompt)
                save_history(histpath)
                if (r.startswith("run ")):
                    m = re.match(r"^run\s+(([^\s]|([\\][\s]))+)\s*$", r)
                    if m:
                        fn = m.group(1)
                        f = open(fn, "r")
                        s = f.read()
                        a = compile(s, fn, "exec")
                        if a:
                            d = dict(locals(), **globals())
                            exec(a, d, d)
                    else:
                        print("wrong run command: \"%s\""%r)
                    r = ""
                return r
            console.raw_input = my_input
            console.interact(banner="e.g. run chat_ms_opt_maid2.py")
            exit(0)
        print(f"Parent process: Front End PID = {child_pid}, Parent PID = {os.getpid()}")
    if term_pid == -1 or term_pid == child_pid2:
        child_pid2 = os.fork()
        if child_pid2 == 0: # Back End Python Process especially for GPU
            for csocket in csockets:
                csocket.close()
            setproctitle.setproctitle("my_preload_python.py - Back End")
            setup_recv_and_workqueues()
            exit(0)
        # Note: don't close the psockets and csockets for next children
        print(f"Parent process: Back End PID = {child_pid2}, Parent PID = {os.getpid()}")
    if term_pid == -1:
        setproctitle.setproctitle("my_preload_python.py - Mothership")
        def handle_signals_to_survive(signum, frame):
            pass
        signal.signal(signal.SIGINT, handle_signals_to_survive) # Ctrl+C
        signal.signal(signal.SIGCHLD, handle_signals_to_survive) # survive when a runaway child process is killed
    term_pid, status = os.waitpid(-1, 0)
Usage
$ cat test.py
import torch
x = torch.tensor([1,2])
print("on frontend: %s"%x)
# print("on frontend: %s"%x.cuda()) # forked frontend process cannot use CUDA
@exec_function_on_backend
def test(x):
    import torch
    x = x.cuda()
    print("on backend: %s"%x)
    return x.cpu()
print("returned: %s"%test(x))
@exec_function_on_backend
def callback_test(f):
    y = torch.tensor([1,2]).cuda()
    f(y)
def callback_handle(s):
print("on callback: %s"%(s)) # Note: callback is executed on backend!
callback_test(callback_handle)
$ python my_preload_python.py
Parent process: Front End PID = 300288, Parent PID = 300287
Parent process: Back End PID = 300289, Parent PID = 300287
e.g. run chat_ms_opt_maid2.py
>>> run test.py
on frontend: tensor([1, 2])
on backend: tensor([1, 2], device='cuda:0')
returned: tensor([1, 2])
on callback: tensor([1, 2], device='cuda:0')