Help us understand the problem. What is going on with this article?

【5行くらいで書ける】keras.wrappers.scikit_learnにおけるResourceExhaustedError対策

More than 1 year has passed since last update.

概要

Kerasのモデルをscikit-learnライクに扱えるkeras.wrappers.scikit_learnを使うことで、GridSearchCVなどでハイパーパラメータサーチが簡単にできるようになります(ハイパーパラメータサーチを高度化・簡略化するこんなクラスも作ってみました)。

一方、ハイパーパラメータサーチのようにモデル構築を繰り返す場合、ResourceExhaustedErrorが発生することがあります。

ResourceExhaustedErrorの対策の1つとして、ことあるごとにclear_sessionを実行するというものがありますが、度々サーチが停止しては面倒です。なのでシンプルにfit前に毎回clear_sessionを実行することにします。

※ ここではgpu環境、バックエンドをtensorflowの想定です。

実装

まずは必要モジュールをインポート。

import tensorflow as tf
import keras.backend.tensorflow_backend as KTF
from keras.wrappers.scikit_learn import KerasClassifier

次にKerasClassifierを継承し、クラスを1つ定義。

class MyKerasClassifier(KerasClassifier):
    def fit(self, *args, **kwargs):
        KTF.clear_session()
        session = tf.Session("")
        KTF.set_session(session)
        super().fit(*args, **kwargs)

これで対策はOK。このクラスでモデルを定義すれば、ResourceExhaustedErrorの発生を抑えられます。

n_classes = 10
def mk_nw(activation, lr, out_dim):
    model = Sequential()
    model.add(Conv2D(20, kernel_size=5, strides=1, 
                     activation=activation, input_shape=Xtrain.shape[1:]))
    model.add(MaxPool2D(2, strides=2))

    model.add(Conv2D(50, kernel_size=5, strides=1, activation=activation))
    model.add(MaxPool2D(2, strides=2))

    model.add(Flatten())
    model.add(Dense(out_dim, activation=activation))
    model.add(Dense(n_classes, activation="softmax"))

    model.compile(loss="categorical_crossentropy", optimizer=optimizers.SGD(lr=lr))
    return model

estimator = MyKerasClassifier(mk_nw, activation="linear", 
                              lr=0.01, out_dim=256, epochs=16, verbose=0)

注意

MyKerasClassifierをfitする(=clear_sessionを実行する)と、clear_sessionの実装に書いてある通り、グラフ上のモデルがすべてDestroyされます。必要なモデルがグラフ上に乗っている場合は事前に退避させてください。

clear_sessionの実装

https://github.com/keras-team/keras/blob/master/keras/backend/tensorflow_backend.py

def clear_session():
    """Destroys the current TF graph and creates a new one.

    Useful to avoid clutter from old models / layers.
    """
    global _SESSION
    global _GRAPH_LEARNING_PHASES
    tf.reset_default_graph()
    reset_uids()
    _SESSION = None
    phase = tf.placeholder_with_default(False,
                                        shape=(),
                                        name='keras_learning_phase')
    _GRAPH_LEARNING_PHASES = {}
    _GRAPH_LEARNING_PHASES[tf.get_default_graph()] = phase
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away