背景
- Matplotlibをmultiprocessing.Poolで並列化すると画像が乱れることがあった
- 問題なく描画できることも多かったので発生条件の絞り込みに時間がかかった
問題の再現
- Poolで分岐する前にpyplotを呼んでいる場合のみ、画像の乱れが生じた
- 以下のプログラムでto_failフラグをFalseにすると正しく描画できる
import multiprocessing as mp
import matplotlib.pyplot as plt
import numpy as np
def myplot(x):
rng = np.random.RandomState(10 ** 6 * x)
data = rng.randn(10, 10)
fig, ax = plt.subplots()
c = ax.imshow(data)
fig.colorbar(c)
fig.savefig(f"img/{x:02d}.png")
plt.close(fig)
def multi_fail(n, to_fail):
if to_fail:
myplot(999)
with mp.Pool(4) as p:
p.map(myplot, range(n))
multi_fail(10, to_fail=True)
対策
上手くいくこともあるが、matplotlibは安易に並列化しないほうが良さそうだ
追記(解決)
-
multiprocessing
モジュールはUnix上ではデフォルトでメインプロセスのコピーを立てるが、context
を指定することでサブプロセスの生成方法を変更できる。
https://docs.python.org/ja/3/library/multiprocessing.html#contexts-and-start-methods - これを利用してプログラムを下記のように書き換えると、問題が生じなくなった。当初のプログラムで
to_fail
フラグを立てると問題が生じるということは、メインプロセスの状態がサブプロセスたちに影響しているということなので、context
をfork
からspawn
に変更して綺麗なサブプロセスを作るのが正攻法だ。- いや、そもそも
os.fork
とmatplotlib
が両方正しく動いていればfork
のままで問題ないのだが。
- いや、そもそも
- プログラムの実行部を
if __name__ == "__main__":
で括る必要がある。そうしないと子プロセスが孫プロセスを、孫プロセスが曾孫プロセスを、、、と無限ループが生じる。
# myplot関数の宣言は上記と同じ
import multiprocessing as mp
def multi_ok(n):
myplot(999)
with mp.get_context("spawn").Pool(4) as p: # 変更
p.map(myplot, range(n))
if __name__ == "__main__": # 変更
multi_ok(10)