LoginSignup
1
3

More than 3 years have passed since last update.

scikit-optimizeのEarlyStopperで最適化を中断する

Posted at

前置き

前回記事があります。
Pythonでめっちゃ簡単にできるベイズ最適化

長い間、サンプリングを続ければ良いものの、一定の評価値に達したら中断したい場合、scikit-optimizeに用意されているCallbackで止めるようにしないと、あらかじめ指定したサンプリング回数が終わるまで止められません。
プロセスキルすれば止まるけど、そこまでの結果を受け取って取り扱うことができない。

というわけで、scikit-optimize.callbacks.EarlyStopperを使って、任意の条件で最適化を中断したいが、いかんせんこのEarlyStopperは使い方がろくに説明されていない。
help()で見ろとか書いてあるんだが、出てくるのは下記のリンクと同レベル。

というわけで、本記事では日本語で使い方を記録しておきたいと思います。

skopt.callbacks.EarlyStopper

In [0]:from skopt.callbacks import EarlyStopper
In [1]:help(EarlyStopper)
Help on class EarlyStopper in module skopt.callbacks:

class EarlyStopper(builtins.object)
 |  Decide to continue or not given the results so far.
 |  
 |  The optimization procedure will be stopped if the callback returns True.
 |  
 |  Methods defined here:
 |  
 |  __call__(self, result)
 |      Parameters
 |      ----------
 |      result : `OptimizeResult`, scipy object
 |          The optimization as a OptimizeResult object.
 |  
 |  ----------------------------------------------------------------------
 |  Data descriptors defined here:
 |  
 |  __dict__
 |      dictionary for instance variables (if defined)
 |  
 |  __weakref__
 |      list of weak references to the object (if defined)

ソースコード全文

前回でも解説しているので、変更のある部分で抜粋して解説する。

import numpy as np
from skopt import gp_minimize
from skopt.callbacks import EarlyStopper

class Stopper(EarlyStopper):
    def __call__(self, result):
        ret = False
        if result.fun < -1.0:
            ret = True
        return ret

def func(param=None):
    ret = np.cos(param[0] + 2.34) + np.cos(param[1] - 0.78)
    return -ret

if __name__ == '__main__':
    x1 = (-np.pi, np.pi)
    x2 = (-np.pi, np.pi)
    x = (x1, x2)
    result = gp_minimize(func, x, 
                          n_calls=30,
                          noise=0.0,
                          model_queue_size=1,
                          callback=[Stopper()],
                          verbose=True)

EarlyStopperの継承クラス

from skopt.callbacks import EarlyStopper

class Stopper(EarlyStopper):
    def __call__(self, result):
        ret = False
        if result.fun < -1.0:
            ret = True
        return ret

なんとなく、予想はできたけど、継承して自前で実装せよということだった。
というわけで、EarlyStopperをインポートして、自作クラスStopperを作っている。
コンストラクタとかは分からないけどオーバーライドする必要なし。
__call__をオーバーライドするだけでよい。
この時、__call__resultという引数を設けているが、最適化の最中に都度都度の最新のresultが渡されてくる。
今回はこのresultの中にあるfunというメンバーを参照し、この値が-1.0を下回っていたらTrue、そうでなければFalseを返す仕組みにした。
この__call__がTrueを返すと中断、Falseなら続行なのである。

gp_minimizeへ設定

    result = gp_minimize(func, x, 
                          n_calls=30,
                          noise=0.0,
                          model_queue_size=1,
                          callback=[Stopper()],
                          verbose=True)

前回では設定しなかったcallbackという引数にStopperクラスのインスタンスを渡しておく。
この時、Listで渡しているのは複数のCallbackを受け付けることができるようになっているため。
一つでもListで渡しておく。
それ以外は前回のまま。
これでサンプリング回数30回以下で止まればよいが、当然-1.0という閾値に引っかからなければ止まらない。
※が、問題が簡単すぎるので、絶対にイケる。

結果

Iteration No: 1 started. Evaluating function at random point.
Iteration No: 1 ended. Evaluation done at random point.
Time taken: 0.0000
Function value obtained: -0.9218
Current minimum: -0.9218
Iteration No: 2 started. Evaluating function at random point.
Iteration No: 2 ended. Evaluation done at random point.
Time taken: 0.0000
Function value obtained: 0.9443
Current minimum: -0.9218
Iteration No: 3 started. Evaluating function at random point.
Iteration No: 3 ended. Evaluation done at random point.
Time taken: 0.0000
Function value obtained: 1.6801
Current minimum: -0.9218
Iteration No: 4 started. Evaluating function at random point.
Iteration No: 4 ended. Evaluation done at random point.
Time taken: 0.0000
Function value obtained: -0.0827
Current minimum: -0.9218
Iteration No: 5 started. Evaluating function at random point.
Iteration No: 5 ended. Evaluation done at random point.
Time taken: 0.0000
Function value obtained: -1.1247
Current minimum: -1.1247
Iteration No: 6 started. Evaluating function at random point.

5回目のサンプリングで評価値が-1.1247と-1.0を下回ったので、6回目をやろうとして止まった。
成功。
まぁ、絶対にひっかかるであろう閾値にはしたが・・・

蛇足

他にもscikit-optimizeにはいろいろなCallbackが用意されているが、任意の条件でとなるとこのEarlyStopperしかない。
で、多分継承して自前で実装しなければならないのも、このEarlyStopperだけと思う。
他のはインスタンス作成時にパラメータを設定しておけば、今回のもの同様に設定できるはず。
ちょっとどう考えればよいかわかりにくいので、使ってない・・・

skopt.callbacks

1
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
3