概要
KerasRegressorをpickleやjoblibで保存しようとするとエラーになりますが、
保存できるようにする方法です。
ソリューション
以下でKerasRegressorをモンキーパッチ
def KerasRegressor__getstate__(self):
result = { 'sk_params': self.sk_params }
with tempfile.TemporaryDirectory() as dir:
if hasattr(self, 'model'): # 親Estimatorによるcloneなどで存在しないケースがある
self.model.save(dir + '/output.h5', include_optimizer=False)
with open(dir + '/output.h5', 'rb') as f:
result['model'] = f.read()
return result
KerasRegressor.__getstate__ = KerasRegressor__getstate__
def KerasRegressor__setstate__(self, serialized):
self.sk_params = serialized['sk_params']
with tempfile.TemporaryDirectory() as dir:
model_data = serialized.get('model')
if model_data:
with open(dir + '/input.h5', 'wb') as f:
f.write(model_data)
self.model = models.load_model(dir + '/input.h5')
KerasRegressor.__setstate__ = KerasRegressor__setstate__
解説
__getstate__, __setstate__
を使うと、pickleのシリアライズ、デシリアライズをクラスごとにカスタマイズできる。 (詳細はぐぐって)