scikit-learnのParallelで並列処理

More than 5 years have passed since last update.

scikit-learnのチュートリアルといえば@Scaled_Wurmマンの紹介がとても分かりやすい。

今回はたまたまソースコードを読んでいて、かつそのブログエントリでは紹介さていなかったニッチなところをメモする。

結論としては特に理由が無いならmultiprocessingで書いている部分はParallelで置き換えても良さそう、ということ。


Parallel


  • 並列処理をおこなうクラス

  • 基本的にはmultiprocessingで並列処理をおこなうんだけど、あったらいいな〜というヘルプ機能を提供してくれる。


multiprocessingじゃダメなの?

Parallelは (原文はソースコード中のNotes)


  • 関数の引数をリストで作らなくもていいよ

  • debugが簡単だよ


    • コードを変更しなくてもn_jobs=1にしたら並列機能を無くせる

    • 子プロセスで吐いたエラーもtracebackしてくれる



  • 進捗度合いを見れるよ

  • 面倒なことしなくてもmultiprocessingのジョブをCtrl-Cで中止できるよ (参考)


引数


  • n_jobs (int): 並列処理するcpuの数

  • verbose (int): 出力するメッセージの詳細さ。10以上ですべての反復でメッセージ出力。

  • pre_dispatch: 並列処理を開始する前に呼び出されるcpuの数



  • 例を見たほうが早い

  • delayedは並列処理をおこないたい関数のラッパーで、関数とその引数を返す。Parallelのcall()が呼び出された時にそれらの関数が実行される。


簡単な例

>>> from math import sqrt

>>> from sklearn.externals.joblib import Parallel, delayed
>>> Parallel(n_jobs=1)(delayed(sqrt)(i**2) for i in range(10))
[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]


進捗を見る例

>>> from time import sleep

>>> from sklearn.externals.joblib import Parallel, delayed
>>> r = Parallel(n_jobs=2, verbose=5)(delayed(sleep)(.1) for _ in range(10)) #doctest: +SKIP
[Parallel(n_jobs=2)]: Done 1 out of 10 | elapsed: 0.1s remaining: 0.9s
[Parallel(n_jobs=2)]: Done 3 out of 10 | elapsed: 0.2s remaining: 0.5s
[Parallel(n_jobs=2)]: Done 6 out of 10 | elapsed: 0.3s remaining: 0.2s
[Parallel(n_jobs=2)]: Done 9 out of 10 | elapsed: 0.5s remaining: 0.1s
[Parallel(n_jobs=2)]: Done 10 out of 10 | elapsed: 0.5s finished


  • verboseに入れる値が大きいほどたくさんメッセージを出してくれる

  • ジョブの数、経過時間などを出力してくれる


pre_dispatchを指定する


  • defaultは'all'

  • 一つのジョブが使うメモリが多かったら減らしたりして使う感じなのかな

  • 下の例では3 (1.5 * 2 jobs) つのジョブが並列処理をおこなう前に呼び出されている

>>> from math import sqrt

>>> from sklearn.externals.joblib import Parallel, delayed

>>> def producer():
... for i in range(6):
... print('Produced %s' % i)
... yield i

>>> out = Parallel(n_jobs=2, verbose=100, pre_dispatch='1.5*n_jobs')(
... delayed(sqrt)(i) for i in producer()) #doctest: +SKIP
Produced 0 ### 1つ目
Produced 1 ### 2つ目
Produced 2 ### 3つ目
[Parallel(n_jobs=2)]: Done 1 jobs | elapsed: 0.0s
Produced 3
[Parallel(n_jobs=2)]: Done 2 jobs | elapsed: 0.0s
Produced 4
[Parallel(n_jobs=2)]: Done 3 jobs | elapsed: 0.0s
Produced 5
[Parallel(n_jobs=2)]: Done 4 jobs | elapsed: 0.0s
[Parallel(n_jobs=2)]: Done 5 out of 6 | elapsed: 0.0s remaining: 0.0s
[Parallel(n_jobs=2)]: Done 6 out of 6 | elapsed: 0.0s finished