#Background
- When training on multiple GPUs in a single machine with Keras (TensorFlow backend), I tried to save intermediate (checkpoint) models periodically by passing the keras.callbacks.ModelCheckpoint callback to the callbacks list, and got the following error.
```
Epoch 01000: saving model to ./log_Stacked_Hourglass_Network_v1_320x240_fp16_3rd_multiGPU/weights.1000-0.0039607391-0.0131516113.hdf5
Traceback (most recent call last):
  File "train_customized_SHN_resume_baobab_fp16_1stack_multiGPU.py", line 304, in <module>
    callbacks=[cp_cb], validation_data=(X_val,[y_val]), shuffle=False)
  ...(omitted)...
  File "/usr/lib/python3.5/copy.py", line 306, in _reconstruct
    y.__dict__.update(state)
AttributeError: 'NoneType' object has no attribute 'update'
```
- After some digging I eventually came full circle back to the Keras documentation, which says the following:
To save the multi-GPU model, use `.save(fname)` or `.save_weights(fname)` with the template model (the model you passed as the argument to `multi_gpu_model`), rather than the model returned by `multi_gpu_model`.
- In other words, when going multi-GPU, the callback that saves intermediate checkpoints can't be the one Keras provides out of the box; you have to write it yourself! The pattern the documentation recommends is sketched below.
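
To make the situation concrete, here is a minimal sketch of the setup that triggers the error and of the save pattern the documentation recommends. The model definition and variable names are placeholders made up for illustration; they are not taken from the original training script.

```python
# Hypothetical sketch of the problematic setup and the documented workaround.
from keras.models import Sequential
from keras.layers import Dense
from keras.utils import multi_gpu_model
from keras.callbacks import ModelCheckpoint

# "base_model" is the template model; "parallel_model" is the replicated
# model that actually runs the training across the GPUs.
base_model = Sequential([Dense(64, activation='relu', input_shape=(100,)),
                         Dense(1)])
parallel_model = multi_gpu_model(base_model, gpus=2)
parallel_model.compile(optimizer='adam', loss='mse')

# This combination is what raised the AttributeError above:
# ModelCheckpoint tries to save the multi-GPU model itself.
cp_cb = ModelCheckpoint('weights.{epoch:04d}.hdf5')

# What the documentation recommends instead: save via the template model,
# e.g. base_model.save('model.hdf5') or base_model.save_weights('weights.hdf5').
# The custom callback below does exactly that at the end of each epoch.
```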
#Example callback for saving intermediate models during multi-GPU training in Keras
- I found a callback implementation to use as a reference at the following site.
- I consolidated it into a single file and (with a few small tweaks) made it usable; see below.
- Feel free to import and use it (a rough usage sketch follows after the code).
```python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import csv
import six
import numpy as np
import time
import json
import warnings

# from collections import deque
# from collections import OrderedDict
# from collections import Iterable
# from .utils.generic_utils import Progbar
# from . import backend as K
# from .engine.topology import Layer

try:
    import requests
except ImportError:
    requests = None


class Callback(object):
    """Abstract base class used to build new callbacks.

    # Properties
        params: dict. Training parameters
            (eg. verbosity, batch size, number of epochs...).
        model: instance of `keras.models.Model`.
            Reference of the model being trained.

    The `logs` dictionary that callback methods
    take as argument will contain keys for quantities relevant to
    the current batch or epoch.

    Currently, the `.fit()` method of the `Sequential` model class
    will include the following quantities in the `logs` that
    it passes to its callbacks:

        on_epoch_end: logs include `acc` and `loss`, and
            optionally include `val_loss`
            (if validation is enabled in `fit`), and `val_acc`
            (if validation and accuracy monitoring are enabled).
        on_batch_begin: logs include `size`,
            the number of samples in the current batch.
        on_batch_end: logs include `loss`, and optionally `acc`
            (if accuracy monitoring is enabled).
    """

    def __init__(self):
        self.validation_data = None
        self.model = None

    def set_params(self, params):
        self.params = params

    def set_model(self, model):
        self.model = model

    def on_epoch_begin(self, epoch, logs=None):
        pass

    def on_epoch_end(self, epoch, logs=None):
        pass

    def on_batch_begin(self, batch, logs=None):
        pass

    def on_batch_end(self, batch, logs=None):
        pass

    def on_train_begin(self, logs=None):
        pass

    def on_train_end(self, logs=None):
        pass


# ==========================================================================
#
# Multi-GPU Model Save Callback
#
# ==========================================================================
class MultiGPUCheckpointCallback(Callback):

    def __init__(self, filepath, base_model, monitor='val_loss', verbose=0,
                 save_best_only=False, save_weights_only=False,
                 mode='auto', period=1):
        super(MultiGPUCheckpointCallback, self).__init__()
        self.base_model = base_model
        self.monitor = monitor
        self.verbose = verbose
        self.filepath = filepath
        self.save_best_only = save_best_only
        self.save_weights_only = save_weights_only
        self.period = period
        self.epochs_since_last_save = 0

        if mode not in ['auto', 'min', 'max']:
            warnings.warn('ModelCheckpoint mode %s is unknown, '
                          'fallback to auto mode.' % (mode),
                          RuntimeWarning)
            mode = 'auto'

        if mode == 'min':
            self.monitor_op = np.less
            self.best = np.Inf
        elif mode == 'max':
            self.monitor_op = np.greater
            self.best = -np.Inf
        else:
            if 'acc' in self.monitor or self.monitor.startswith('fmeasure'):
                self.monitor_op = np.greater
                self.best = -np.Inf
            else:
                self.monitor_op = np.less
                self.best = np.Inf

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        self.epochs_since_last_save += 1
        if self.epochs_since_last_save >= self.period:
            self.epochs_since_last_save = 0
            filepath = self.filepath.format(epoch=epoch + 1, **logs)
            if self.save_best_only:
                current = logs.get(self.monitor)
                if current is None:
                    warnings.warn('Can save best model only with %s available, '
                                  'skipping.' % (self.monitor), RuntimeWarning)
                else:
                    if self.monitor_op(current, self.best):
                        if self.verbose > 0:
                            print('Epoch %05d: %s improved from %0.5f to %0.5f,'
                                  ' saving model to %s'
                                  % (epoch + 1, self.monitor, self.best,
                                     current, filepath))
                        self.best = current
                        if self.save_weights_only:
                            self.base_model.save_weights(filepath, overwrite=True)
                        else:
                            self.base_model.save(filepath, overwrite=True)
                    else:
                        if self.verbose > 0:
                            print('Epoch %05d: %s did not improve' %
                                  (epoch + 1, self.monitor))
            else:
                if self.verbose > 0:
                    print('Epoch %05d: saving model to %s' % (epoch + 1, filepath))
                if self.save_weights_only:
                    self.base_model.save_weights(filepath, overwrite=True)
                else:
                    self.base_model.save(filepath, overwrite=True)
```
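
For reference, here is a rough usage sketch. It assumes the code above has been saved as `multi_gpu_callbacks.py` next to the training script; the model, data, and file names are placeholders, not those of the original training script.

```python
# Hypothetical usage sketch, assuming the callback above lives in
# multi_gpu_callbacks.py and that at least two GPUs are available.
import numpy as np
from keras.models import Sequential
from keras.layers import Dense
from keras.utils import multi_gpu_model

from multi_gpu_callbacks import MultiGPUCheckpointCallback

# Placeholder model and data, just to keep the example self-contained.
base_model = Sequential([Dense(64, activation='relu', input_shape=(100,)),
                         Dense(1)])
parallel_model = multi_gpu_model(base_model, gpus=2)
parallel_model.compile(optimizer='adam', loss='mse')

X_train = np.random.rand(1024, 100)
y_train = np.random.rand(1024, 1)

# Pass the *template* model (base_model) to the callback: that is the model
# that gets written to disk, while parallel_model is the one being trained.
cp_cb = MultiGPUCheckpointCallback(
    filepath='weights.{epoch:04d}-{loss:.4f}.hdf5',
    base_model=base_model,
    monitor='loss',
    save_best_only=True,
    period=1)

parallel_model.fit(X_train, y_train, batch_size=64, epochs=10,
                   callbacks=[cp_cb])
```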