# When You Run Out of Memory Doing Image Recognition with a CNN

Posted at 2018-11-19

## Introduction
For work I had been doing image recognition with a CNN. Classification is probably the more common task, but this time it was regression. These are my notes on the memory problem I ran into along the way and how I worked around it.

## Goal
Simply put: count the number of fish in a given image.
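
Judging from the generator code further down, the target value for each image is simply its parent directory name cast to an integer, so the dataset is presumably organized as one folder per fish count. A tiny illustration of that assumption (the folder and file names here are hypothetical):

```python
from pathlib import Path

# e.g. data/3/img_0001.jpg -> the image contains 3 fish
path = Path("data/3/img_0001.jpg")
count = int(path.parent.name)   # -> 3
```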

## The usual approach

CNNs are very easy to write with Keras, and there are plenty of example networks around, so I will skip that part. You load the images into memory and then train in one shot like this:

```python:cnn_re.py
model.fit(X_train, y_train, batch_size=batch_size, epochs=epochs, verbose=1, validation_data=(X_val, y_val))
```
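
For reference, the approach this assumes is to decode every image into a float32 array up front, which is exactly what eventually stops fitting in memory. A minimal sketch of that kind of loading, using my own hypothetical directory and variable names:

```python
import numpy as np
import cv2
from pathlib import Path

X, y = [], []
for path in Path("data").glob("**/*.jpg"):
    img = cv2.imread(str(path))               # decoded ndarray, much bigger than the file on disk
    X.append(img.astype(np.float32) / 255.)
    y.append(int(path.parent.name))           # parent directory name = fish count

X_train = np.array(X)                          # the whole dataset sits in RAM at once
y_train = np.array(y).reshape(-1, 1)
```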

## The problem

This worked fine while the number of images was small, but as the dataset grew it no longer fit in memory and training died with an error.
With Keras you can feed the data in small chunks (one minibatch at a time) by using model.fit_generator instead of model.fit. For generators, Keras conveniently provides ImageDataGenerator. Unfortunately, this is a regression problem rather than classification, so it seems the only way to supply numeric targets is to write the generator yourself (there may well be a built-in way that I simply don't know about). So I wrote a generator like the one below.

The calling side:

```python:cnn_re.py
subfolder = Path(args.fromdir)
pathlist = Path(subfolder).glob('**/*.jpg')
image_file_paths = list(pathlist)
image_file_paths = [str(item) for item in image_file_paths]

random.shuffle(image_file_paths)
all_file_len = len(image_file_paths)
test_pos = int(all_file_len * 0.1)  # first 10% of the data for test
val_pos = test_pos * 2              # next 10% for validation

test_gen = FishImageGenerator(image_file_paths[0:test_pos], batch_size, img_rows, img_cols, channels)
valid_gen = FishImageGenerator(image_file_paths[test_pos:val_pos], batch_size, img_rows, img_cols, channels)
train_gen = FishImageGenerator(image_file_paths[val_pos:], batch_size, img_rows, img_cols, channels)

model.fit_generator(
    generator=train_gen,
    epochs=epochs,
    steps_per_epoch=len(train_gen),
    verbose=1,
    validation_data=valid_gen,
    validation_steps=len(valid_gen),
    callbacks=[es_cb,
               TensorBoard(log_dir=log_dir),
               ModelCheckpoint(model_path, save_best_only=True)]
    )
```
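
One note on the snippet above: `es_cb`, `model`, `batch_size` and friends are defined elsewhere in the script and are not shown in the article. In particular, `es_cb` is presumably an early-stopping callback, something along these lines (my guess):

```python
from keras.callbacks import EarlyStopping

# Assumed definition of the early-stopping callback passed to fit_generator above.
es_cb = EarlyStopping(monitor='val_loss', patience=5, verbose=1)
```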

The generator side:

```python:cnn_re.py
import math
import random
from pathlib import Path

import cv2
import numpy as np
from keras.utils import Sequence

import logging
logger = logging.getLogger(__name__)


class FishImageGenerator(Sequence):
    def __init__(self, file_paths=None, batch_size=16, img_rows=256, img_cols=256, channels=3):
        self._file_paths = file_paths
        self._batch_size = batch_size
        self._img_rows = img_rows
        self._img_cols = img_cols
        self._channels = channels

        if file_paths is None:
            self._file_paths = self._get_paths()

        self._array_len = len(self._file_paths)
        self._num_batches_per_epoch = math.ceil(self._array_len / batch_size)
        self.reset()

    def reset(self):
        """Reshuffle the file list and rebuild the target list.

        The parent directory name of each image is its target value (fish count).
        """
        random.shuffle(self._file_paths)
        self.inner_targets = []
        for path_in_str in self._file_paths:
            path_obj = Path(path_in_str)
            self.inner_targets.append(path_obj.parent.name)

    def get_targets(self):
        return self.inner_targets

    def _get_paths(self, root_dir=None):
        """Collect image paths below root_dir (default: the current directory)."""
        if root_dir is None:
            root_dir = Path.cwd()
        pathlist = Path(root_dir).glob('**/*.jpg')
        return [str(p) for p in pathlist]

    def __getitem__(self, idx):
        """Return one minibatch of (images, targets)."""
        start_pos = self._batch_size * idx
        end_pos = start_pos + self._batch_size
        if end_pos > self._array_len:
            end_pos = self._array_len
        item_paths = self._file_paths[start_pos: end_pos]

        loaded_data = []
        target_label = []

        for path_in_str in item_paths:
            path_obj = Path(path_in_str)
            try:
                img = cv2.imread(path_in_str)
                loaded_data.append(img)
                target_label.append(path_obj.parent.name)
            except Exception as e:
                logger.error(e)

        imgs = np.array(loaded_data).astype(np.float32)
        imgs = imgs / 255.
        imgs = imgs.reshape(-1, self._img_rows, self._img_cols, self._channels)
        targets = np.array(target_label).astype("uint8")
        targets = targets.reshape(-1, 1)

        return imgs, targets

    def __len__(self):
        """Number of batches per epoch."""
        return self._num_batches_per_epoch

    def on_epoch_end(self):
        """Reshuffle at the end of each epoch."""
        self.reset()
```

How to write a generator is covered in the Keras documentation, so I won't go into detail. There are several ways to do it, but it is enough to inherit from Sequence and implement `__getitem__`, `__len__` and `on_epoch_end`. The purpose of each method is noted in the comments.

When I tried it, the out-of-memory error was indeed gone, but re-reading every image from disk on every pass did not feel right either. The obvious next thought is a cache. Python 3 has lru_cache in functools, so using that should make it easy:

```python:generator.py
import functools

@functools.lru_cache(maxsize=30000)
def read_pic(path_to_pic):
    pic_data = cv2.imread(path_to_pic)
    return pic_data
```

Caching was enabled just by adding two lines: `import functools` and the `@functools.lru_cache` decorator.

But something was off: the cache hit rate printed by the `print(read_pic.cache_info())` line was surprisingly low. Enlarging maxsize improved it a little, but I had hoped the cache would do far more work than that.
After thinking about it a bit more, the fool here was obviously me. Every epoch walks through the entire training set, so with an LRU policy, stuffing the just-used items into the cache is actively counterproductive: by the time an item is needed again, it has already been evicted. Well then, I might as well write the cache myself.
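
As a quick illustration of why LRU is the wrong policy here (this toy snippet is mine, not from the original code): when the access pattern cycles through a dataset larger than the cache, every entry is evicted before it is requested again, so the hit rate stays at zero.

```python
from functools import lru_cache

@lru_cache(maxsize=100)
def load(i):
    return i  # stand-in for decoding an image

for _ in range(2):            # two "epochs"
    for i in range(1000):     # dataset larger than the cache
        load(i)

print(load.cache_info())
# CacheInfo(hits=0, misses=2000, maxsize=100, currsize=100)
```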

```python:generator.py
import functools
from pathlib import Path
import cv2
import numpy as np
import random
from keras.utils import Sequence
import math
from collections import namedtuple
from _thread import RLock

import logging
logger = logging.getLogger(__name__)

_CacheInfo = namedtuple("CacheInfo", ["hits", "misses", "maxsize", "currsize"])

def _make_key(args):
    return hash(args[0])

def simple_cache(maxsize=128):
    if maxsize is not None and not isinstance(maxsize, int):
        raise TypeError('Expected maxsize to be an integer or None')

    def decorating_function(user_function):
        wrapper = _simple_cache_wrapper(user_function, maxsize, _CacheInfo)
        return functools.update_wrapper(wrapper, user_function)

    return decorating_function

def _simple_cache_wrapper(user_function, maxsize, _CacheInfo):
    sentinel = object()          # unique object used to signal cache misses
    make_key = _make_key         # build a key from the function arguments

    cache = {}
    hits = misses = 0
    full = False
    cache_get = cache.get    # bound method to lookup a key or return None
    cache_len = cache.__len__  # get cache size without calling len()
    lock = RLock()           # because linkedlist updates aren't threadsafe

    if maxsize == 0 or maxsize is None:

        def wrapper(*args, **kwds):
            # No caching -- just a statistics update after a successful call
            nonlocal misses
            result = user_function(*args, **kwds)
            misses += 1
            return result

    else:

        def wrapper(*args, **kwds):
            # Size limited caching that tracks accesses by recency
            nonlocal hits, misses, full
            key = make_key(args)
            with lock:  # CPython's dict operations are atomic on their own, but the lock
                        # is cheap and gives peace of mind around the counter updates.
                result = cache_get(key, sentinel)
                if result is not sentinel:
                    hits += 1
                    return result
            result = user_function(*args, **kwds)
            with lock:
                if key in cache:
                    # The same key was added to the cache by another call while
                    # the lock was released; just fall through and count the miss.
                    pass
                elif full:
                    # Our simple cache is full, so just skip the incoming item
                    # (no eviction, unlike LRU).
                    pass
                else:
                    cache[key] = result
                    full = (cache_len() >= maxsize)
                misses += 1
            return result

    def cache_info():
        """Report cache statistics"""
        with lock:
            return _CacheInfo(hits, misses, maxsize, cache_len())

    def cache_clear():
        """Clear the cache and cache statistics"""
        nonlocal hits, misses, full
        with lock:
            cache.clear()
            hits = misses = 0
            full = False

    wrapper.cache_info = cache_info
    wrapper.cache_clear = cache_clear
    return wrapper

@simple_cache(maxsize=2000)
def read_pic(path_to_pic):
    pic_data = cv2.imread(path_to_pic)
    return pic_data

class FishImageGenerator(Sequence):
    def __init__(self,file_paths=None,batch_size = 16,img_rows=256, img_cols=256, channels=3):
        self._file_paths = file_paths
        self._batch_size = batch_size
        self._img_rows = img_rows
        self._img_cols = img_cols
        self._channels = channels

        if file_paths is None:
            self._file_paths = self._get_paths()

        self._array_len = len(self._file_paths)
        self._num_batches_per_epoch = math.ceil(self._array_len / batch_size)
        self.reset()

    def reset(self):
        random.shuffle(self._file_paths)
        self.inner_targets = []
        for path_in_str in self._file_paths:
            path_obj = Path(path_in_str)
            self.inner_targets.append(path_obj.parent.name)

    def get_targets(self):
        return self.inner_targets

    def _get_paths(self, root_dir=None):
        """Collect image paths below root_dir (default: the current directory)."""
        if root_dir is None:
            root_dir = Path.cwd()
        pathlist = Path(root_dir).glob('**/*.jpg')
        return [str(p) for p in pathlist]

    def __getitem__(self, idx):
        start_pos = self._batch_size * idx
        end_pos = start_pos + self._batch_size
        if end_pos > self._array_len:
            end_pos = self._array_len
        item_paths = self._file_paths[start_pos: end_pos]

        loaded_data = []
        target_label = []

        for path_in_str in item_paths:
            path_obj = Path(path_in_str)
            try:
                img = read_pic(path_in_str)
                #img = cv2.imread(path_in_str)
                loaded_data.append(img)
                target_label.append(path_obj.parent.name)
            except Exception as e:
                print(e)
                logger.error(e)

        imgs = np.array(loaded_data).astype(np.float32)
        imgs = imgs /255.
        imgs = imgs.reshape(-1, self._img_rows, self._img_cols, self._channels)
        targets = np.array(target_label).astype("uint8")
        targets = targets.reshape(-1, 1)

        #self.inner_imgs,self.inner_targets = img,targets
        return imgs, targets

    def __len__(self):
        """Number of batches per epoch."""
        return self._num_batches_per_epoch

    def on_epoch_end(self):
        """Log cache statistics and reshuffle at the end of each epoch."""
        print(read_pic.cache_info())
        logger.info(read_pic.cache_info())
        self.reset()
```

That finally felt a bit cleaner. But memory consumption seemed to be around ten times the size of the image files, which sharply limits how much the cache can hold. Thinking about it, that makes sense: the moment a compressed image file is decoded into a NumPy ndarray it balloons, so of course fewer images fit in the cache.

```python:test.py
import sys
import cv2

# path_to_pic: path to one of the JPEG files
pic_data = cv2.imread(path_to_pic)
with open(path_to_pic, 'rb') as f:
    binary = f.read()

print("array size is {}".format(sys.getsizeof(pic_data)))
print("binary size is {}".format(sys.getsizeof(binary)))
```

When I measured it, there was indeed roughly a twenty-fold difference:

    array size is 196736
    binary size is 7833

So, the following part of the earlier source

```python:generator.py
def read_pic(path_to_pic):
    pic_data = cv2.imread(path_to_pic)
    return pic_data
```

just needs to be rewritten as below, and we're done:

```python:generator.py
@simple_cache(maxsize=150000)
def read_pic(path_to_pic):
    # Cache the raw, still-compressed file bytes instead of the decoded array.
    #pic_data = cv2.imread(path_to_pic)
    with open(path_to_pic, 'rb') as f:
        binary = f.read()
    return binary

def bin2arr(binary):
    # Decode the cached bytes into an image array on demand.
    arr = np.asarray(bytearray(binary), dtype=np.uint8)
    pic_data = cv2.imdecode(arr, -1)
    return pic_data
```
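
With this version, the place in `__getitem__` that used to get a decoded image back from `read_pic` now receives raw bytes, so they have to be run through `bin2arr` before being appended to the batch. The article does not show that line explicitly; my reading of the change amounts to something like this hypothetical helper:

```python
def load_image(path_in_str):
    """Combine the byte cache and the decoder, as __getitem__ would use them."""
    binary = read_pic(path_in_str)   # cheap once cached: raw JPEG bytes
    return bin2arr(binary)           # decode to an ndarray on every access
```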

After a fair amount of trial and error it finally took shape. The trade-off is that CPU time now goes into decompressing the image files on every access; how to balance that against memory is something that can only be tuned by watching how training actually behaves.

Oh, one thing I forgot to mention: most of the source for the home-made cache was lifted from Python's functools.
