
Working around failed intermediate model saves when training on multiple GPUs with Keras

Last updated at Posted at 2018-09-02

Background

  • While training with Keras (TensorFlow backend) on multiple GPUs in a single machine, I registered keras.callbacks.ModelCheckpoint as a callback to save intermediate models (checkpoints) periodically, and the following error occurred.

Epoch 01000: saving model to ./log_Stacked_Hourglass_Network_v1_320x240_fp16_3rd_multiGPU/weights.1000-0.0039607391-0.0131516113.hdf5
Traceback (most recent call last):
  File "train_customized_SHN_resume_baobab_fp16_1stack_multiGPU.py", line 304, in <module>
    callbacks=[cp_cb], validation_data=(X_val,[y_val]), shuffle=False)
...(omitted)...
  File "/usr/lib/python3.5/copy.py", line 306, in _reconstruct
    y.__dict__.update(state)
AttributeError: 'NoneType' object has no attribute 'update'

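  • After a lot of digging I came full circle back to the Keras documentation, which says the following.

To save the multi-gpu model, use .save(fname) or .save_weights(fname) with the template model (the argument you passed to multi_gpu_model), rather than the model returned by multi_gpu_model.

  • In other words, the stock ModelCheckpoint callback saves the model that fit is called on, which here is the return value of multi_gpu_model. For multi-GPU training you therefore have to provide your own checkpoint callback that saves the template model instead.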
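As a minimal sketch of what that documentation note means in practice, the snippet below trains the model returned by `multi_gpu_model` but saves through the template. The function name `train_and_save_template`, the optimizer, the loss, and the output filename are assumptions for illustration only. The Keras import sits inside the function purely so that the snippet can be loaded on a machine without Keras installed.

```python
def train_and_save_template(template_model, x, y, gpus=2,
                            out_path="template_weights.hdf5"):
    """Train on several GPUs, then save via the template model (sketch)."""
    # Imported lazily so that defining this function does not require Keras.
    from keras.utils import multi_gpu_model

    # The returned model runs replicas of the template on each GPU;
    # per the Keras docs, the template shares the same weights.
    parallel_model = multi_gpu_model(template_model, gpus=gpus)
    parallel_model.compile(optimizer="rmsprop", loss="mse")  # assumed settings
    parallel_model.fit(x, y, epochs=1, batch_size=32)

    # Save through the template, not through parallel_model.
    template_model.save_weights(out_path)
    return out_path
```

This only saves once, after training has finished. Saving at intervals during training needs a callback that holds a reference to the template model.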

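Example callback for saving intermediate models in Keras multi-GPU training

  • The following site had a callback function that served as a model.
  • I merged it into a single file (with a few small fixes) so that it is ready to use. See below.
    • Feel free to import it and use it.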
python3:multiGPUCheckPointCallback.py

#

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import warnings

import numpy as np


class Callback(object):
    """Abstract base class used to build new callbacks.

    # Properties
        params: dict. Training parameters
            (eg. verbosity, batch size, number of epochs...).
        model: instance of `keras.models.Model`.
            Reference of the model being trained.

    The `logs` dictionary that callback methods
    take as argument will contain keys for quantities relevant to
    the current batch or epoch.

    Currently, the `.fit()` method of the `Sequential` model class
    will include the following quantities in the `logs` that
    it passes to its callbacks:

        on_epoch_end: logs include `acc` and `loss`, and
            optionally include `val_loss`
            (if validation is enabled in `fit`), and `val_acc`
            (if validation and accuracy monitoring are enabled).
        on_batch_begin: logs include `size`,
            the number of samples in the current batch.
        on_batch_end: logs include `loss`, and optionally `acc`
            (if accuracy monitoring is enabled).
    """

    def __init__(self):
        self.validation_data = None
        self.model = None

    def set_params(self, params):
        self.params = params

    def set_model(self, model):
        self.model = model

    def on_epoch_begin(self, epoch, logs=None):
        pass

    def on_epoch_end(self, epoch, logs=None):
        pass

    def on_batch_begin(self, batch, logs=None):
        pass

    def on_batch_end(self, batch, logs=None):
        pass

    def on_train_begin(self, logs=None):
        pass

    def on_train_end(self, logs=None):
        pass


# ==========================================================================
#
#  Multi-GPU Model Save Callback
#
# ==========================================================================



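class MultiGPUCheckpointCallback(Callback):
    """Checkpoint callback for models wrapped with `multi_gpu_model`.

    Works like `keras.callbacks.ModelCheckpoint`, except that it saves
    `base_model` (the template model passed to `multi_gpu_model`) instead
    of the model that Keras attaches to the callback during `fit`.
    """
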
    def __init__(self, filepath, base_model, monitor='val_loss', verbose=0,
                 save_best_only=False, save_weights_only=False,
                 mode='auto', period=1):
        super(MultiGPUCheckpointCallback, self).__init__()
        self.base_model = base_model
        self.monitor = monitor
        self.verbose = verbose
        self.filepath = filepath
        self.save_best_only = save_best_only
        self.save_weights_only = save_weights_only
        self.period = period
        self.epochs_since_last_save = 0

        if mode not in ['auto', 'min', 'max']:
            warnings.warn('ModelCheckpoint mode %s is unknown, '
                          'fallback to auto mode.' % (mode),
                          RuntimeWarning)
            mode = 'auto'

        # `np.inf` is used because the `np.Inf` alias was removed in NumPy 2.0.
        if mode == 'min':
            self.monitor_op = np.less
            self.best = np.inf
        elif mode == 'max':
            self.monitor_op = np.greater
            self.best = -np.inf
        else:
            # 'auto': maximize accuracy-like metrics, minimize everything else.
            if 'acc' in self.monitor or self.monitor.startswith('fmeasure'):
                self.monitor_op = np.greater
                self.best = -np.inf
            else:
                self.monitor_op = np.less
                self.best = np.inf

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        self.epochs_since_last_save += 1
        if self.epochs_since_last_save >= self.period:
            self.epochs_since_last_save = 0
            filepath = self.filepath.format(epoch=epoch + 1, **logs)
            if self.save_best_only:
                current = logs.get(self.monitor)
                if current is None:
                    warnings.warn('Can save best model only with %s available, '
                                  'skipping.' % (self.monitor), RuntimeWarning)
                else:
                    if self.monitor_op(current, self.best):
                        if self.verbose > 0:
                            print('Epoch %05d: %s improved from %0.5f to %0.5f,'
                                  ' saving model to %s'
                                  % (epoch + 1, self.monitor, self.best,
                                     current, filepath))
                        self.best = current
                        if self.save_weights_only:
                            self.base_model.save_weights(filepath, overwrite=True)
                        else:
                            self.base_model.save(filepath, overwrite=True)
                    else:
                        if self.verbose > 0:
                            print('Epoch %05d: %s did not improve' %
                                  (epoch + 1, self.monitor))
            else:
                if self.verbose > 0:
                    print('Epoch %05d: saving model to %s' % (epoch + 1, filepath))
                if self.save_weights_only:
                    self.base_model.save_weights(filepath, overwrite=True)
                else:
                    self.base_model.save(filepath, overwrite=True)
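
The `filepath` argument is expanded with `str.format` on every save, using the 1-based epoch number and whatever Keras puts in `logs`. The sketch below shows that expansion and one possible way to wire the callback into training. The pattern string is only a guess at a template that would produce a filename like the one in the error log, and `checkpoint_name`, `train_with_checkpoints`, the optimizer, the loss, and `period=100` are likewise assumptions for illustration, not the author's actual training script.

```python
# Hypothetical pattern; fields are filled from the epoch number and the logs dict.
CHECKPOINT_PATTERN = "weights.{epoch:02d}-{loss:.10f}-{val_loss:.10f}.hdf5"


def checkpoint_name(pattern, epoch, logs):
    """Expand the pattern the way on_epoch_end does (epoch is 0-based)."""
    return pattern.format(epoch=epoch + 1, **logs)


def train_with_checkpoints(base_model, x, y, x_val, y_val, gpus=2, epochs=1000):
    """Fit the multi-GPU model while checkpointing the template model."""
    # Imported lazily so that loading this sketch does not require Keras.
    from keras.utils import multi_gpu_model
    from multiGPUCheckPointCallback import MultiGPUCheckpointCallback

    parallel_model = multi_gpu_model(base_model, gpus=gpus)
    parallel_model.compile(optimizer="rmsprop", loss="mse")  # assumed settings

    cp_cb = MultiGPUCheckpointCallback(
        filepath=CHECKPOINT_PATTERN,
        base_model=base_model,      # the template, not parallel_model
        monitor="val_loss",
        verbose=1,
        save_weights_only=True,
        period=100,                 # assumed: save every 100 epochs
    )
    return parallel_model.fit(x, y, epochs=epochs, callbacks=[cp_cb],
                              validation_data=(x_val, y_val))


print(checkpoint_name(CHECKPOINT_PATTERN, 999,
                      {"loss": 0.0039607391, "val_loss": 0.0131516113}))
# prints: weights.1000-0.0039607391-0.0131516113.hdf5
```

Keys in `logs` that the pattern does not mention are simply ignored by `str.format`, but a pattern that names a key missing from `logs` (for example `val_loss` when no validation data is passed) raises `KeyError`.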
