前回はこちら
はじめに
今日はspark
の標準ライブラリであるMLlibを利用してみます。
spark
にはMLlib
を含め、4種類の標準ライブラリがあります。
他のライブラリに関しては、こちらを参照ください。
MLlibとは
spark上に実装されている機械学習のライブラリです。
機械学習の実装はすごくムズカシイのですが、MLlibは割と簡単に試すことができ、サポートしているアルゴリズムも豊富なのでとっつきやすい気がします(私見です)。
とりあえず、協調フィルタリングを試してみます。
協調フィルタリングについてはこのブログが大変参考になりました。
実装手順
1. 必要なライブラリをimportする
from pyspark.mllib.recommendation import ALS, MatrixFactorizationModel, Rating
2. データを読み込んでRDD
を生成する
テストデータはコレを利用します。
レビューした人のID、レビューされたもののID、レビュー結果といった感じでしょうか?
data = sc.textFile("data/mllib/als/test.data")
3. 適当に変数を設定します
rank:特徴量数、numIterations:繰り返し数といった変数を設定します。
rankは増やせば増やすほど性能が改善するらしいです。10ぐらいから変わらなくなるらしいのでとりあえず10にしときます。
numIterationsも増やせば増やすだけ学習サイクルが増えるので精度が上がる筈です。
これもとりあえず10回にしときます。
rank = 10
numIterations = 10
4. modelを作成して、data
を学習させます
- で読み込んだライブラリの
train
メソッドを叩くだけです。
model = ALS.train(ratings, rank, numIterations)
5. 適当に値をぶっこんで予測をしてもらいます。
>>> model.predict(2,4)
1.0015225077674874
(2,4)
に対する正解は1なのでまぁまぁいいのかな?
6. モデルを保存したり、ロードしたりします。
折角学習させたデータなので、保存して任意のタイミングで利用できるようにしたいですね。
save
とMatrixFactorizationModel.load
を利用すればそれも簡単にできるそうです。
model.save(sc, "own/model/path")
sameModel = MatrixFactorizationModel.load(sc, "own/model/path")
終わりに
ALS
を利用するところはわりかしさくっとできました。
どちらかというと、アルゴリズムの選定や理解の方が厄介そうです。
それくらいAPIは素晴らしいので何かしら機械学習を試してみたい方は
spark
から始めてもよいとおもいました。
こちらに試せるアルゴリズムが紹介されているのでご参照ください。
次回は、「mySQL」とつなぐ、仮想サーバを用いて分散処理を試してみる
のどちらかを書こうと思います。
おわり