概要
Kerasのモデルをscikit-learnライクに扱えるkeras.wrappers.scikit_learnを使うことで、GridSearchCV
などでハイパーパラメータサーチが簡単にできるようになります(ハイパーパラメータサーチを高度化・簡略化するこんなクラスも作ってみました)。
一方、ハイパーパラメータサーチのようにモデル構築を繰り返す場合、ResourceExhaustedErrorが発生することがあります。
ResourceExhaustedErrorの対策の1つとして、ことあるごとにclear_session
を実行するというものがありますが、度々サーチが停止しては面倒です。なのでシンプルに__fit前に毎回clear_session
を実行する__ことにします。
※ ここではgpu環境、バックエンドをtensorflowの想定です。
実装
まずは必要モジュールをインポート。
import tensorflow as tf
import keras.backend.tensorflow_backend as KTF
from keras.wrappers.scikit_learn import KerasClassifier
次にKerasClassifier
を継承し、クラスを1つ定義。
class MyKerasClassifier(KerasClassifier):
def fit(self, *args, **kwargs):
KTF.clear_session()
session = tf.Session("")
KTF.set_session(session)
super().fit(*args, **kwargs)
これで対策はOK。このクラスでモデルを定義すれば、ResourceExhaustedErrorの発生を抑えられます。
n_classes = 10
def mk_nw(activation, lr, out_dim):
model = Sequential()
model.add(Conv2D(20, kernel_size=5, strides=1,
activation=activation, input_shape=Xtrain.shape[1:]))
model.add(MaxPool2D(2, strides=2))
model.add(Conv2D(50, kernel_size=5, strides=1, activation=activation))
model.add(MaxPool2D(2, strides=2))
model.add(Flatten())
model.add(Dense(out_dim, activation=activation))
model.add(Dense(n_classes, activation="softmax"))
model.compile(loss="categorical_crossentropy", optimizer=optimizers.SGD(lr=lr))
return model
estimator = MyKerasClassifier(mk_nw, activation="linear",
lr=0.01, out_dim=256, epochs=16, verbose=0)
注意
MyKerasClassifier
をfitする(=clear_session
を実行する)と、clear_session
の実装に書いてある通り、グラフ上のモデルがすべて__Destroy__されます。必要なモデルがグラフ上に乗っている場合は事前に退避させてください。
clear_session
の実装
def clear_session():
"""Destroys the current TF graph and creates a new one.
Useful to avoid clutter from old models / layers.
"""
global _SESSION
global _GRAPH_LEARNING_PHASES
tf.reset_default_graph()
reset_uids()
_SESSION = None
phase = tf.placeholder_with_default(False,
shape=(),
name='keras_learning_phase')
_GRAPH_LEARNING_PHASES = {}
_GRAPH_LEARNING_PHASES[tf.get_default_graph()] = phase