LoginSignup
7
7

More than 5 years have passed since last update.

luigiでパラメータチューニング(2)

Last updated at Posted at 2015-12-05

ニッチなことしか書いていない自分にしては反応が大きかった気がしたのでもう少し色々やってみた。という話です。

汎用タスクを作ろう

今回のコードはこちら。前回とやっていることは同じです。
https://github.com/keisuke-yanagisawa/study/blob/20151205/luigi/param_tuning.py

今回のテーマは、汎用のタスクを作りたいということ。
いろんなタスクを作るわけですが、やっぱり汎用のものを作ると嬉しいのはどこでもいっしょ。そこで、

  • パラメータチューニングの結果を集約し
  • 最良値を出したパラメータだけをまとめてcsv形式で出力する

ということを汎用的にやってくれるタスクを作ってみました。

class param_tuning(luigi.Task):
    tasks        = luigi.Parameter()              # luigi.Taskの1次元配列
    text_format  = luigi.Parameter()              # pythonの「変数名記述付きの」正規表現を渡す
    reduce_pivot = luigi.Parameter()              # どの変数を集約で利用するか
    reduce_rule  = luigi.Parameter(default="min") # 集約する関数を指定, min or max
    out_file     = luigi.Parameter()              # 出力ファイル名

    def requires(self):
        return self.tasks;
    def output(self):
        return luigi.LocalTarget(self.out_file)

    def run(self):

        # making pandas dataframe
        results = []
        for task in self.requires():
            with task.output().open() as taskfile:
                string = taskfile.read()
                groupdict = re.search(self.text_format, string).groupdict()
                results.append(groupdict)
        df = pd.DataFrame.from_dict(results);
        df[self.reduce_pivot] = convert2num(df[self.reduce_pivot])
        values = df[self.reduce_pivot]

        # Aggregation of parameter tuning results
        if self.reduce_rule == "min":
            best_val = min(values)
        elif self.reduce_rule == "max":
            best_val = max(values)
        else:
            print("reduce_rule must be min or max. your input is %s" % self.reduce_rule)
            exit(1);

        # Rearrangement of column order
        column_order = filter(lambda key: key != self.reduce_pivot, df.columns) + [self.reduce_pivot]
        df = df[column_order]

        # Outputting results as csv formatted data
        df[df[self.reduce_pivot] == best_val].to_csv(self.output().fn, index=False);

集約関係は色々コーディングが面倒になり、pandasに任せました。
やっていることは非常に簡単で、

  1. 一つ一つのパラメータでの計算をrequires()で実行し
  2. 結果をすべてpandas dataframeに集計
  3. 最も良い値を集約部分で求め
  4. (csvの最後に集約で用いた値がくるように編集して)
  5. 最も良い値になっているものだけを(場合によっては複数個)出力する

という仕組みです。

inputはちょっとややこしいですね。tasksやreduce系, outputと基本的にはわかると思うのですが、汎用にするために正規表現をぶち込むようなインターフェースにしたらなんかキモくなりました。

使い方

githubに上げているコード自体の利用方法は、

python param_tuning.py main_task --local-scheduler

とかやってくれれば動く気がします。

また、この汎用タスクの利用方法に関してですが、別途「パラメータでの計算を行うタスク」「main関数となるタスク」を用意します。

mainタスクに記述する正規表現が多分一番の難点でして(私が正規表現使ってなさすぎ)、それについて説明。
今回の計算実行タスクtask_param_evalcost,gamma,errorという一行のcsv形式のファイルを出力するようになっているので、以下のように指定します。

s = "[-+]?\d*\.\d+|\d+" ## float or int expression
text_format  = "(?P<cost>"+s+"),(?P<gamma>"+s+"),(?P<error>"+s+")"

?P<name>とすることで、名前を指定することができます。これは出力csvのheaderやpivotの名前として利用するので、ちゃんと指定してください。

終わりに

私、コーディングしている割には絶望的にモジュール化が下手な人間なのですが、こういう強制力が働くと「仕方なく」切ったりすることができてすごく良いと思っています。今回は、そのモジュール化によって得られる恩恵の1つである汎用物品の作成をしてみました。
…コーディングがそもそもへたっぴなのはご指導ご鞭撻頂戴できれば幸いです。

参考資料

数字を正規表現で発見する方法に関しては以下のstack overflowから拝借。
http://stackoverflow.com/questions/4703390/how-to-extract-a-floating-number-from-a-string-in-python

7
7
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
7
7