0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

Pythonでの開発効率化のための手探り(フロントエンドとバックエンドを分ける)

Posted at 2023-10-22

昨今はPythonモジュールや機械学習のウェイトの読み込みが遅くなってきて開発効率が悪くなってきたので何か良い仕組み作れないかなと色々手探りしてみました。

とりあえずクライアント・サーバ式にしようと思い、IPCや共有メモリを導入しようと思ったのですが、途中でプリロードしたプロセスをforkするだけで良いのでは?と思い始めて実験。でもこれは外れでした。CUDAがforkに対応していないようです。

代わりにforkとsocket pairとpickleでフロントエンドとバックエンドを分けるテストをしてそこそこ動くようになった気がするので試験的に公開してみます。ただソケットペアが完全に安全かはイマイチ確証が持てていないので何かあっても責任は持てません。悪しからず。

(2023年10月23日〜24日追記: コンテキスト引き継いだrunコマンド実装、リモートバックエンド実行のデコレータ追加、戻り値対応、引数渡しのバイナリ化、リモート実行エラーの詳細表示、コールバック対応、酷いバグの修正、タイトル変更)

my_preload_python.py
# preloaderにRCEを組み込むテスト、親一人子二人
# module preloader for fast debug!
# SPDX-License-Identifier: Apache-2.0

# 暴走止める用
# kill -s KILL `ps -ef | grep my_preload_python | grep -v grep | awk '{print $2}'`

import signal
import subprocess
import os

import readline # command history for InteractiveConsole
import code

import setproctitle

import socket

# One socket pair per logical CPU; the per-CPU locks and back-end worker threads
# created further down are indexed the same way.
# NOTE(review): os.cpu_count() may return None when the count is undeterminable,
# which would make range(num_cpus) below fail.
num_cpus = os.cpu_count()

# psockets[i] is the end used by the back end (recv_and_exec_loop); the front end
# closes all psockets after fork. csockets[i] is the end used by the front end
# (send_to_backend); the back end closes all csockets after fork.
psockets = []
csockets = []

for i in range(num_cpus):
    # pipe may be faster but we use socketpair for security
    # NOTE(review): on Linux, AF_UNIX appears to treat SOCK_RAW as SOCK_DGRAM, so each
    # sendall() becomes a single datagram and very large pickles may be rejected
    # (e.g. EMSGSIZE). SOCK_STREAM is the commented-out alternative. Confirm before
    # relying on large payloads.
    psocket, csocket = socket.socketpair(socket.AF_UNIX, socket.SOCK_RAW) # socket.SOCK_STREAM
    psockets.append(psocket)
    csockets.append(csocket)

import select

import atexit
import psutil

# The interactive console's command history is kept in the standard Python history file.
histpath = os.path.expanduser("~/.python_history")


def save_history(histpath):
    """Write the in-memory readline history to the file at *histpath*."""
    return readline.write_history_file(histpath)


# Flush the history one final time when the interpreter shuts down.
atexit.register(save_history, histpath)



# バックエンドで実行するデコレータ
# usage:
# @exec_function_on_backend
# def test(x, y):
#     r = x + y
#     print(r)
#
# test(1, 3)
"""
import marshal, pickle
def exec_function_on_backend(func):
    def wrapper(*args, **kwargs):
        return send_to_backend((marshal.dumps(func.__code__), func.__name__, args, kwargs))
    return wrapper
"""

import dill as pickle
def exec_function_on_backend(func):
    """Decorator: run *func* in the back-end process instead of locally.

    Calling the decorated function sends ``(func, func.__name__, args, kwargs)``
    through ``send_to_backend`` and returns whatever the back end replies with.

    Args:
        func: the function to execute remotely. It is serialized with dill, so it
            must be picklable together with the globals it references.

    Returns:
        A wrapper with the same name/docstring as *func* (via ``functools.wraps``).
    """
    import functools

    # Preserve __name__, __doc__, __qualname__ and expose __wrapped__ so the
    # decorated function still looks like itself in the console (help(), repr).
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return send_to_backend((func, func.__name__, args, kwargs))
    return wrapper

import threading
# Per-CPU client-side state, indexed like csockets:
#   per_cpu_mutexes[i] serializes requests on CPU i's socket;
#   per_cpu_fs[i] is a reader file object over that socket, created on first use.
per_cpu_mutexes = [threading.Lock() for _ in range(num_cpus)]
per_cpu_fs = [None for _ in range(num_cpus)]

def send_to_backend(data):
    """Send one request to the back end and block until its reply arrives.

    The request goes over the socket pair belonging to the CPU the caller is
    currently running on. The caller is pinned to that CPU with
    ``os.sched_setaffinity(0, ...)``. That CPU's lock is held for the whole
    exchange, so each socket carries at most one request/reply at a time.

    Args:
        data: any object ``send_data`` can serialize; ``exec_function_on_backend``
            passes ``(func, func.__name__, args, kwargs)``.

    Returns:
        The unpickled reply, or ``None`` if the reply could not be decoded
        (the error is printed).
    """
    cpu = psutil.Process().cpu_num()
    os.sched_setaffinity(0, [cpu])
    sock = csockets[cpu]
    with per_cpu_mutexes[cpu]: # acquire per-cpu lock
        if per_cpu_fs[cpu] is None:
            # Buffered reader over the socket's fd so pickle.load can stream from it.
            per_cpu_fs[cpu] = os.fdopen(sock.fileno(), "rb")
        reader = per_cpu_fs[cpu]

        # The drain step of the previous call left the socket non-blocking; switch
        # back to blocking *before* sending so sendall() cannot fail with
        # BlockingIOError when the peer's receive queue is momentarily full.
        sock.setblocking(True)
        send_data(sock, data)

        obj = None
        try:
            obj = pickle.load(reader)
        except Exception as e:
            print(e)

        # Discard any bytes left over from this exchange (best effort) so they
        # cannot be mistaken for the start of the next reply.
        sock.setblocking(False)
        try:
            reader.read()
        except OSError:
            pass

        return obj

def send_data(socket, data):
    """Serialize *data* with dill and push every byte of it through *socket*.

    ``recurse=True`` asks dill to trace the globals that pickled functions
    actually reference rather than storing the whole global dictionary.
    """
    payload = pickle.dumps(data, recurse=True)
    socket.sendall(payload)

import types
import traceback

wq_backend_threads = []
def setup_recv_and_workqueues():
    """Start one ``recv_and_exec_loop`` worker thread per CPU, then join them all.

    The threads are recorded in the module-level ``wq_backend_threads`` list.
    Each worker runs an endless loop, so this call normally blocks indefinitely.
    """
    wq_backend_threads.extend(
        threading.Thread(target=recv_and_exec_loop, args=(cpu_index,))
        for cpu_index in range(num_cpus)
    )
    for worker in wq_backend_threads:
        worker.start()
    for worker in wq_backend_threads:
        worker.join()

def recv_and_exec_loop(cpu): # no locks, please lock on client side
    os.sched_setaffinity(0, [cpu])

    f = os.fdopen(psockets[cpu].fileno(), "rb")
    r = None
    while True:
        psockets[cpu].setblocking(True)
        try:
            obj = pickle.load(f)
#            print("backend: data load")
            try:
#                func = types.FunctionType(marshal.loads(obj[0]), globals(), obj[1])
#                r = func(*obj[2], *obj[3])
                r = obj[0](*obj[2], *obj[3])
#                print("backend: exec func")
            except Exception as e:
                traceback.print_exc()
        except Exception as e:
            print(e)
        try:
            psockets[cpu].setblocking(False)
            _ = f.read() # バッファ破棄
#            print("backend: drop trailing data")
        except:
            pass
        send_data(psockets[cpu], r)
#        print("backend: send data")


# Supervisor ("Mothership") loop: fork a front-end console child and a back-end
# worker child, then wait. Whichever child terminates is forked again on the next
# pass, so the modules imported above stay preloaded for every respawn.
term_pid = -1

while True:
    # (Re)spawn the front end on the first pass or when it was the child that exited.
    if term_pid == -1 or term_pid == child_pid:
        child_pid = os.fork()

    if child_pid == 0: # Front End Python Process
        for psocket in psockets:
            psocket.close()

        setproctitle.setproctitle("my_preload_python.py - Front End")

        # A respawned front end inherits the supervisor's no-op SIGINT handler
        # (installed below on the first pass); restore Python's default so
        # Ctrl+C raises KeyboardInterrupt in the console again.
        signal.signal(signal.SIGINT, signal.default_int_handler)

        console = code.InteractiveConsole(locals=locals())

        try:
            readline.read_history_file(histpath) # command history
        except FileNotFoundError:
            pass # no history saved yet (first run)

        def my_input(prompt):
            """Console input hook: run "run <file>" locally, pass other lines through.

            "run <file>" compiles and executes <file> with this module's globals
            (so helpers such as exec_function_on_backend are visible) and hands
            the console an empty line. Spaces in the file name may be escaped
            with a backslash.
            """
            import re

            r = input(prompt)
            save_history(histpath)
            if (r.startswith("run ")):
                m = re.match(r"^run\s+(([^\s]|([\\][\s]))+)\s*$", r)
                if m:
                    fn = re.sub(r"\\(\s)", r"\1", m.group(1)) # unescape "\ "
                    # Report errors from the script instead of letting them
                    # escape interact() and kill the front end.
                    try:
                        with open(fn, "r") as f:
                            s = f.read()
                        a = compile(s, fn, "exec")
                        if a:
                            d = dict(locals(), **globals())
                            exec(a, d, d)
                    except Exception:
                        traceback.print_exc()
                else:
                    print("wrong run command: \"%s\""%r)
                r = ""
            return r
        console.raw_input = my_input

        console.interact(banner="e.g. run chat_ms_opt_maid2.py")

        exit(0)

    print(f"Parent process: Front End PID = {child_pid}, Parent PID = {os.getpid()}")

    # (Re)spawn the back end on the first pass or when it was the child that exited.
    if term_pid == -1 or term_pid == child_pid2:
        child_pid2 = os.fork()

    if child_pid2 == 0: # Back End Python Process especially for GPU
        for csocket in csockets:
            csocket.close()
        setproctitle.setproctitle("my_preload_python.py - Back End")

        setup_recv_and_workqueues()
        exit(0)

    # Note: don't close the psockets and csockets for next children

    print(f"Parent process: Back End PID = {child_pid2}, Parent PID = {os.getpid()}")

    if term_pid == -1:
        setproctitle.setproctitle("my_preload_python.py - Mothership")
        # No-op Python handlers keep the supervisor alive on Ctrl+C (SIGINT) and SIGCHLD.
        def handle_signals_to_survive(signum, frame):
            pass
        signal.signal(signal.SIGINT, handle_signals_to_survive) # Ctrl+C
        signal.signal(signal.SIGCHLD, handle_signals_to_survive) # when kill the out-of-control child process

    # Block until either child exits; its pid decides which one is respawned next pass.
    term_pid, status = os.waitpid(-1, 0)

使い方

$ cat test.py
# Example script for the front end: load it with "run test.py" at the console prompt.
import torch

x = torch.tensor([1,2])
print("on frontend: %s"%x)

# print("on frontend: %s"%x.cuda()) # forked frontend process cannot use CUDA

# Executed in the back-end process: x is serialized and sent over, moved to the
# GPU there with .cuda(), and a CPU copy is returned to the front end.
@exec_function_on_backend
def test(x):
    import torch
    x = x.cuda()
    print("on backend: %s"%x)
    return x.cpu()

print("returned: %s"%test(x))

# The callback argument is serialized along with the call, so f(y) runs in the back end.
@exec_function_on_backend
def callback_test(f):
    y = torch.tensor([1,2]).cuda()
    f(y)

def callback_handle(s):
    print("on callback: %s"%(s)) # Note: callback is executed on backend!

callback_test(callback_handle)

$ python my_preload_python.py 
Parent process: Front End PID = 300288, Parent PID = 300287
Parent process: Back End PID = 300289, Parent PID = 300287
e.g. run chat_ms_opt_maid2.py
>>> run test.py
on frontend: tensor([1, 2])
on backend: tensor([1, 2], device='cuda:0')
returned: tensor([1, 2])
on callback: tensor([1, 2], device='cuda:0')
0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?