0
1

More than 1 year has passed since last update.

Pythonのmultiprocessを使って並列処理

Posted at

次の記事の続き
https://qiita.com/Octpascal/items/fcfcf172fc8d28f0d874

main

ある関数のメモリ使用量をリアルタイムでプロットするコードを書いた。

方法

  1. multiprocessを使いメモリ監視用の関数を起動
  2. multiprocessを使い測定対象の関数を起動
  3. 測定用関数に対象のPIDを渡す
  4. 対象の関数の終了を検知
  5. 監視用関数を止める

リアルタイムプロットの参考
https://qiita.com/dendensho/items/79e9d2e3d4e8eb5061bc

from multiprocessing import Process, Value, Queue
import os
import psutil
import matplotlib.pyplot as plt
import datetime

# リアルタイムプロットの参考
# https://qiita.com/dendensho/items/79e9d2e3d4e8eb5061bc

WAITING_TIME = 0.1
TIME_H_RANGE = 100

def memory_check(pid:int =os.getpid()):
    '''
    与えられたPIDの物理メモリと仮想メモリ使用量を返す。
    Windowsの場合は仮想メモリ=物理メモリ(共有部分を除く)+ページファイル
    '''
    unit = 1024**2
    mem_info = psutil.Process(pid=pid).memory_info()
    pmem = mem_info.rss / unit # 物理メモリ
    vmem = mem_info.vms / unit # 仮想メモリ
    peak = mem_info.peak_pagefile / unit # 仮想メモリのこれまでの最大値
    return pmem, vmem, peak

def plotting(status: Value, PID: Value):
    '''グラフプロット関数
    status :
        0-5の整数値
        ploting関数のステータス
    PID :
        PIDの整数値
    '''

    # 初期化
    times = [0 for i in range(TIME_H_RANGE)]
    pmems = [0 for i in range(TIME_H_RANGE)]
    vmems = [0 for i in range(TIME_H_RANGE)]
    peaks = [0 for i in range(TIME_H_RANGE)]

    plt.ion()
    plt.figure()
    li_p, = plt.plot(times, pmems, label='Physical')
    li_v, = plt.plot(times, vmems, label='Virtual')
    li_a, = plt.plot(times, peaks, label='Peak Virtual')

    plt.legend()

    plt.xlabel('time')
    plt.ylabel('used memory (MB)')
    plt.title('memory real time plot')

    start_time = datetime.datetime.now()
    # メインループ
    while True:
        if status.value == 1: # アプリケーション起動待ち
            plt.pause(0.1)
        elif status.value == 2: # 実行中
            current_time = (datetime.datetime.now() - start_time).total_seconds() + WAITING_TIME
            try:
                mem = memory_check(PID.value)
            except psutil.NoSuchProcess:
                pass

            times.append(current_time)
            times.pop(0)
            pmems.append(mem[0])
            pmems.pop(0)
            vmems.append(mem[1])
            vmems.pop(0)
            peaks.append(mem[2])
            peaks.pop(0)

            li_p.set_xdata(times)
            li_p.set_ydata(pmems)
            li_v.set_xdata(times)
            li_v.set_ydata(vmems)
            li_a.set_xdata(times)
            li_a.set_ydata(peaks)

            plt.xlim(times[0], times[-1])
            plt.ylim(0, peaks[-1]*1.1)
            plt.draw()

            plt.pause(WAITING_TIME)
        elif status.value == 3: # グラフの更新を停止する
            plt.text(times[-1], peaks[-1], 'Peak Virtual Memory: {:.3f} MB'.format(peaks[-1]), ha='right')
            plt.ioff()
            status.value = 4
        elif status.value == 4: # 終了待ち
            plt.pause(0.1)
        elif status.value == 5: # グラフを閉じる
            plt.clf()
            plt.close()
            return 0
        else: # 異常値
            raise ValueError("statusが異常です")

def local_func(func, q, *args, **kwargs):
    result = func(*args, **kwargs)
    q.put(result)

def realtime_mem_plot(func):
    def wrapper(*args, **kwargs):
        status = Value('i', 0)
        PID = Value('i', 0)
        result_queue = Queue()
        p_plot = Process(target=plotting, args=(status, PID))
        p_plot.start()

        p_main = Process(target=local_func, args=(func, result_queue, *args,), kwargs=kwargs)
        p_main.start()

        print('start plotting')
        parent = psutil.Process(os.getpid())
        children = parent.children()
        PID.value = children[-1].pid
        print(f'PID: {children[-1].pid}')

        status.value = 2

        p_main.join()
        status.value = 3

        input("Please press Enter key to close")
        result = result_queue.get()

        status.value = 5
        p_plot.join()
        return result
    return wrapper

from time import sleep

# @realtime_mem_plot
def f(name, name2, foo=2):
    print(name)
    for j in range(10):
        k = []
        for i in range(100):
            k.append([0]*10000*j)
            sleep(0.01)
        del k
    for i in range(foo):
        print(name2)
    return 10, 20

if __name__ == '__main__':
    f_with_plot = realtime_mem_plot(f)
    result = f_with_plot('test', 'test2', foo=2)
    # result = f('test', 'test2')
    print(result)

前回よりのアップデート

  • 測定対象の関数fの戻り値を受け取ることができるようになった。
    • ラッパー関数からラッパー関数を呼び出すような形にしてあるのがコツ。
  • Peakメモリの取得と、終了時にその値を表示できるようにした。
  • datetimeモジュールを用いることで横軸が正しい値となった。
  • 関数の引数でkwargsを使用できるようにした。

今後の勉強

相変わらずデコレータを使うとエラーが発生する。
調査の結果drillモジュールを使用すると実装できる可能性がありそうだ。
今回local_func関数を導入したが、これを使えば実装が容易そうなので今後完成させたい。

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1