LoginSignup
7
3

More than 5 years have passed since last update.

[TensorFlow] Custom Metrics Functionを使っているmodelのリストア方法

Posted at

CNNでの画像の超解像技術を試すために、下記のようなpsnr関数を作成しそれをmetricsに設定したmodelを構築しました。

# Custom Mertics Function
def psnr(y_true, y_pred):
    return -10*K.log(K.mean(K.flatten((y_true - y_pred))**2)) / np.log(10)
# model Compile
model.compile(
    loss = 'mean_squared_error',
    optimizer = 'adam',
    metrics = [psnr]
)

そのmodelをクラウド環境で学習させ、別環境でロードしようとしたところ「ValueError: Unknown metric function:***」というメッセージが出力されました。

model = load_model('model.h5')
---------------------------------------------------------------------------
ValueError   Traceback (most recent call last)
(省略)
ValueError: Unknown metric function:psnr
---------------------------------------------------------------------------

Custom Functionを使ったmodelを別環境で使用するには、modelをロードする際に引数として、
[custom_objects]を指定するがあるようです。

今回は、psnrというCustom Functionを作成したのでそれを移動先の環境でも定義して指定します。

model = load_model('model.h5', custom_objects={'psnr':psnr})

これでmodelが使用可能になります。
もしかすると、modelをsaveするときにCustom Functionもまとめて保存できる方法があるのも...
ご存知の方がいらっしゃいましたらお教えいただけるとありがたいです。

7
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
7
3