4
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

【5行くらいで書ける】keras.wrappers.scikit_learnにおけるResourceExhaustedError対策

Last updated at Posted at 2018-04-22

概要

Kerasのモデルをscikit-learnライクに扱えるkeras.wrappers.scikit_learnを使うことで、GridSearchCVなどでハイパーパラメータサーチが簡単にできるようになります(ハイパーパラメータサーチを高度化・簡略化するこんなクラスも作ってみました)。

一方、ハイパーパラメータサーチのようにモデル構築を繰り返す場合、ResourceExhaustedErrorが発生することがあります。

ResourceExhaustedErrorの対策の1つとして、ことあるごとにclear_sessionを実行するというものがありますが、度々サーチが停止しては面倒です。なのでシンプルに__fit前に毎回clear_sessionを実行する__ことにします。

※ ここではgpu環境、バックエンドをtensorflowの想定です。

実装

まずは必要モジュールをインポート。

import tensorflow as tf
import keras.backend.tensorflow_backend as KTF
from keras.wrappers.scikit_learn import KerasClassifier

次にKerasClassifierを継承し、クラスを1つ定義。

class MyKerasClassifier(KerasClassifier):
    def fit(self, *args, **kwargs):
        KTF.clear_session()
        session = tf.Session("")
        KTF.set_session(session)
        super().fit(*args, **kwargs)

これで対策はOK。このクラスでモデルを定義すれば、ResourceExhaustedErrorの発生を抑えられます。

n_classes = 10
def mk_nw(activation, lr, out_dim):
    model = Sequential()
    model.add(Conv2D(20, kernel_size=5, strides=1, 
                     activation=activation, input_shape=Xtrain.shape[1:]))
    model.add(MaxPool2D(2, strides=2))

    model.add(Conv2D(50, kernel_size=5, strides=1, activation=activation))
    model.add(MaxPool2D(2, strides=2))

    model.add(Flatten())
    model.add(Dense(out_dim, activation=activation))
    model.add(Dense(n_classes, activation="softmax"))

    model.compile(loss="categorical_crossentropy", optimizer=optimizers.SGD(lr=lr))
    return model

estimator = MyKerasClassifier(mk_nw, activation="linear", 
                              lr=0.01, out_dim=256, epochs=16, verbose=0)

注意

MyKerasClassifierをfitする(=clear_sessionを実行する)と、clear_sessionの実装に書いてある通り、グラフ上のモデルがすべて__Destroy__されます。必要なモデルがグラフ上に乗っている場合は事前に退避させてください。

clear_sessionの実装

def clear_session():
    """Destroys the current TF graph and creates a new one.

    Useful to avoid clutter from old models / layers.
    """
    global _SESSION
    global _GRAPH_LEARNING_PHASES
    tf.reset_default_graph()
    reset_uids()
    _SESSION = None
    phase = tf.placeholder_with_default(False,
                                        shape=(),
                                        name='keras_learning_phase')
    _GRAPH_LEARNING_PHASES = {}
    _GRAPH_LEARNING_PHASES[tf.get_default_graph()] = phase
4
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?