LoginSignup
29
23

More than 5 years have passed since last update.

PySparkで学習済みのscikit-learnのモデルを使う

Last updated at Posted at 2016-01-22

やりたいこと

Sparkで機械学習といえばMLlibだけど、まだまだscikit-learnには機能面で劣っているように思えます。例えば、scikit-learnでは学習時に正例と負例の数が不均等な場合の補正とかできますが、mllibの1.5ではまだそのような機能はありません1。こんな時にメモリに乗る程度のデータで事前にscikit-learnで学習器を作成しておき、それをpysparkで大規模データの予測に使用できるとsklearnとsparkの両者のメリットが活かせるのではと思っています。

Let's Try

方針

データをndarrayのRDDに変換し、RDDのmapで学習済みのモデルのpredictに渡せばいいのですが、そのままやると関数呼び出しのオーバーヘッドが大きそうなのである程度の大きさのバッチ単位で処理したいと思います。

事前準備

Anaconda等のscikit-learnが使えるPython環境を全てのsparkノードに同じパス(/opt/anacondaとか)に準備しておきます。spark-submitコマンドを実行する際にPYSPARK_PYTHON=/opt/anaconda/bin/python3を指定してあげればこのPythonが使用されます。

実装例

事前に学習モデルを作成。今回はRandomForestを使用。データは適当です。

import numpy as np
from sklearn import ensemble

N = 1000
train_x = np.random.randn(N, 10)
train_y = np.random.binomial(1, 0.1, N)

model = ensemble.RandomForestClassifier(10, class_weight="balanced").fit(train_x, train_y)

で、これをPySparkで以下のように使用します。

from pyspark import SparkContext
sc = SparkContext()

test_x = np.random.randn(N * 100, 10)
n_partitions = 10
rdd = sc.parallelize(test_x, n_partitions).zipWithIndex()

# Point 1
def batch(xs):
    yield list(xs)

batch_rdd = rdd.mapPartitions(batch)

# Point 2
b_model = sc.broadcast(model)

def split_id_and_data(xs):
    xs = list(xs)
    data = [x[0] for x in xs]
    ids = [x[1] for x in xs]
    return data, ids

# Point 3
result_rdd = batch_rdd.map(split_id_and_data) \
    .flatMap(lambda x: zip(x[1], b_model.value.predict(x[0])))

for _id, pred in result_rdd.take(10):
    print(_id, pred)

sc.stop()

ポイントは次の3点です

  1. mapPartitionsを使用してRDD[ndarray]RDD[list[ndarray]]に変換しておきます。こうすることである程度の塊のデータをmodel.predictにまとめて渡せます。
  2. 学習済みモデルをbroadcastしておきます。
  3. idsとdataを分離し、dataをb_model.value.predictに渡します。これとidsを再度zipしてflatMapに入れてあげれば完成

(2016-01-26 追記)
partitionないでlistにまとめる

# Point 1
def batch(xs):
    yield list(xs)

batch_rdd = rdd.mapPartitions(batch)

の部分は元々glomというメソッドが用意されていました。

batch_rdd = rdd.glom()

(2016-01-26 追記その2)

DStreamにもglomflatMapメソッドはあるのでSparkStreamingの場合も全く同様に使用できます。SVMで異常検出の学習器を作っておき、ストリーミングデータに対してリアルタイムで適用すると言ったこともできそうです。


  1. 一応JIRAには要望が挙がっていてそろそろ実装されたかもしれませんが、CDH5.5ではsparkが1.5なので使えません。 

29
23
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
29
23