0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

KerasRegressorをpickleやjoblibで保存できるようにする方法

Posted at

概要

KerasRegressorをpickleやjoblibで保存しようとするとエラーになりますが、
保存できるようにする方法です。

ソリューション

以下でKerasRegressorをモンキーパッチ

def KerasRegressor__getstate__(self):
    result = { 'sk_params': self.sk_params }
    with tempfile.TemporaryDirectory() as dir:
        if hasattr(self, 'model'): # 親Estimatorによるcloneなどで存在しないケースがある
            self.model.save(dir + '/output.h5', include_optimizer=False)
            with open(dir + '/output.h5', 'rb') as f:
                result['model'] = f.read()
    return result
KerasRegressor.__getstate__ = KerasRegressor__getstate__

def KerasRegressor__setstate__(self, serialized):
    self.sk_params = serialized['sk_params']
    with tempfile.TemporaryDirectory() as dir:
        model_data = serialized.get('model')
        if model_data:
            with open(dir + '/input.h5', 'wb') as f:
                f.write(model_data)
            self.model = models.load_model(dir + '/input.h5')
KerasRegressor.__setstate__ = KerasRegressor__setstate__

解説

__getstate__, __setstate__ を使うと、pickleのシリアライズ、デシリアライズをクラスごとにカスタマイズできる。 (詳細はぐぐって)

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?