20
19

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

luigiでパラメータチューニング

Last updated at Posted at 2015-12-03

最近、luigiというデータフロー制御フレームワークなるものを利用しているのですが、割と使い勝手が良いと感じているので、ちょっと布教的に文章を書いてみようと思います。
…日本語の資料、少ないんですよね…

luigiそのものについては
http://qiita.com/colspan/items/453aeec7f4f420b91241
http://qiita.com/keisuke-nakata/items/0717c0c358658964f81e
に詳しく書いてありますので、そちらをご参照ください。

簡単に良さを説明すると、luigiはluigi.Taskを継承した子クラス(=タスク)を一つ一つ終了させて、全体の計算結果を得ます。
データの受け渡しをファイルに制限することで、途中でバグがあったり、計算時間limitが超えてしまった場合でも、既に計算された部分を残すことができ、resumeが可能です。(たぶんon-memoryの受け渡しはできない…?→ 追記:luigi.mockというものがあるそうです)

機械学習 with luigi

パラメータチューニングの時にon-memoryで計算していて途中で落ちたりすると、計算全てやり直しで悲しいです。ということで、luigiの良さをつかえるんじゃないかなーと思い、とりあえず、コード書いてみました。
https://github.com/keisuke-yanagisawa/study/blob/20151204/luigi/param_tuning.py
必要なものは

  • numpy
  • scikit-learn
  • luigi

の3つ。

このプログラムは

python param_tuning.py task_param_tuning --local-scheduler

とやると走ります。根となるタスクを指定してあげるわけですね。
あとluigiはschedulerを立ち上げてやるのが普通の使い方なのですが、面倒なので--local-schedulerで単独実行させるようにしています。

では、タスクを見てみましょう。

class task_param_eval(luigi.Task):
    data = luigi.Parameter()
    C = luigi.FloatParameter()
    gamma = luigi.FloatParameter()

    def requires(self):
        return []
    def output(self):
        return luigi.LocalTarget("temp/%s.txt" % hash( frozenset([self.C, self.gamma]) ))
    def run(self):
        model = svm.SVR(C=self.C, gamma=self.gamma)

        # cross_val_score function returns the "score", not "error". 
        # So, the result is inverse of error value.
        results = -cross_validation.cross_val_score(model, data.data, data.target, scoring="mean_absolute_error")
        with self.output().open("w") as out_file:
            out_file.write( str(np.mean(results)) ); 

コード自体はいたって簡単ですね。SVRを使ってクロスバリデーションで評価値を出し、それを平均した値をファイルに出力しています。

luigiのタスクは基本的に**[requires, output, run]の3点セットを上書きする**、と覚えておきましょう。

  • requires ... このタスクを行うためにそもそも実行済みでないといけないタスク
  • output ... このタスクの出力ファイルパス(複数指定可能)
  • run ... タスクの中身

です。出力ファイルパスはluigi.LocalTarget()というおまじないを使います。

また、引数はluigi.Parameter()などを利用します。luigiの内部ではこれらのパラメータを見て、同じタスク名でもパラメータが違ければ実行する、そうでなければ2度同じことは実行しない、を判断しているような気がします。(そのため、Parameterはhashableであることが求められます)

続いて、上記のタスクを複数回呼ぶタスクを見てみましょう。

class task_param_tuning(luigi.Task):

    cost_list = luigi.Parameter(default="1,2,5,10")
    gamma_list = luigi.Parameter(default="1,2,5,10")
    
    data = datasets.load_diabetes()

    def requires(self):
        return flatten_array(
            map(lambda C:
                    map(lambda gamma:
                            task_param_eval(data=frozenset(self.data), # values should be hashable 
                                       C=float(C), gamma=float(gamma)),
                        self.cost_list.split(",")),
                self.gamma_list.split(",")))
    def output(self):
        return luigi.LocalTarget("results.csv")
    def run(self):

        results = {}

        for task in self.requires():
            with task.output().open() as taskfile:
                results[(task.C, task.gamma)] = float(taskfile.read())
        
        best_key = min(results,  key=results.get)
        with self.output().open("w") as out_file:
            out_file.write("%s,%s,%.4f\n" %(best_key[0], best_key[1], results[best_key]))

自分の不勉強で、複数個のパラメータを渡すときの方法をろくに知らないので(怒られそうだ)とりあえずカンマ区切りしてますが、まあそれは置いておいて。
このコードでは、task_param_evalのCやgammaといったパラメータも出力したかったのでrunの中ではfor task in self.requires()としていますが、純粋にrequiresのファイルを読み込めればOKという場合はself.input()とすればself.requires().output()と同じ効果が得られます。

20
19
9

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
20
19

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?