##はじめ
仕事の関係でCNNで画像認識をしていました。普通分類問題が多いと思うが、今回は回帰です。その過程でぶつかったメモリ問題と解決の方法をメモします。
##目的
簡単に言うと、与えられた画像の中に魚の数をカウントする事です。
通常なやり方
Keras でCNNは非常に便利で、ネットワークのサンプルも多くて割愛します。画像をメモリに読み込み、その後下記のように一発で訓練すれば良い。
model.fit(X_train, y_train, batch_size=batch_size, epochs=epochs, verbose=1, validation_data=(X_val, y_val))
問題
画像数が少ない時にはよかったですが、多くなるとメモリに入りきれなくなって、エラーとなりました。
keras ならmodel.fitの代わりにmodel.fit_generatorを使えば、小出し(minibatchごと)に画像データを与えばいいのです。GeneratorならKerasのImageDataGeneratorがあって便利ですね。でも残念なことに今回は分類問題ではなく、回帰となっているので、どうやら目標の数値を与えるにはそのGeneratorが自作するしかないようです(本当はあるが自分が知らないだけかもしれませんが)。それで以下のようにGeneratorを作ってみました。
呼ぶ側:
subfolder = Path(args.fromdir)
pathlist = Path(subfolder).glob('**/*.jpg')
image_file_paths = list(pathlist)
image_file_paths = [str(item) for item in image_file_paths]
random.shuffle(image_file_paths)
all_file_len = len(image_file_paths)
test_pos = int(all_file_len * 0.1) # get 10% of data for test
val_pos = test_pos * 2 # get 10% of data for validation
test_gen = FishImageGenerator(image_file_paths[0:test_pos],batch_size,img_rows,img_cols,channels)
valid_gen = FishImageGenerator(image_file_paths[test_pos:val_pos],batch_size,img_rows,img_cols,channels)
train_gen = FishImageGenerator(image_file_paths[val_pos:],batch_size,img_rows,img_cols,channels)
model.fit_generator(
generator=train_gen,
epochs=epochs,
steps_per_epoch=len(train_gen),
verbose=1,
validation_data=valid_gen,
validation_steps=len(valid_gen),
callbacks=[es_cb,
TensorBoard(log_dir=log_dir),
ModelCheckpoint(model_path, save_best_only=True)]
)
Generator側:
class FishImageGenerator(Sequence):
def __init__(self,file_paths=None,batch_size = 16,img_rows=256, img_cols=256, channels=3):
self._file_paths = file_paths
self._batch_size = batch_size
self._img_rows = img_rows
self._img_cols = img_cols
self._channels = channels
if file_paths is None:
self._file_paths = self._get_paths()
self._array_len = len(self._file_paths)
self._num_batches_per_epoch = math.ceil(self._array_len / batch_size)
self.reset()
def reset(self):
random.shuffle(self._file_paths)
self.inner_targets = []
for path_in_str in self._file_paths:
path_obj = Path(path_in_str)
self.inner_targets.append(path_obj.parent._parts[-1])
def get_targets(self):
return self.inner_targets
def _get_paths(self,root_dir=None):
if root_dir is None:
root_dir = Path.cwd()
pathlist = Path(root_dir).glob('**/*.jpg')
return pathlist
def __getitem__(self, idx):
start_pos = self._batch_size * idx
end_pos = start_pos + self._batch_size
if end_pos > self._array_len:
end_pos = self._array_len
item_paths = self._file_paths[start_pos: end_pos]
loaded_data = []
target_label = []
for path_in_str in item_paths:
path_obj = Path(path_in_str)
try:
img = cv2.imread(path_in_str)
loaded_data.append(img)
target_label.append(path_obj.parent._parts[-1])
except Exception as e:
logger.error(e)
imgs = np.array(loaded_data).astype(np.float32)
imgs = imgs /255.
imgs = imgs.reshape(-1, self._img_rows, self._img_cols, self._channels)
targets = np.array(target_label).astype("uint8")
targets = targets.reshape(-1, 1)
return imgs, targets
def __len__(self):
"""Batch length"""
return self._num_batches_per_epoch
def on_epoch_end(self):
"""Task when end of epoch"""
print(read_pic.cache_info())
logger.info(read_pic.cache_info())
self.reset()
Generatorの書き方について、Kerasのドキュメントにもあるので、割愛します。書き方もいろいろですが、Sequenceから継承して、getitem、len、on_epoch_endをそれぞれ実装すれば良いです。それらのメソッドの目的はコメントしています。
試して見たところ、確かのメモリアウトのエラーは出なくなりましたが、毎回ディスクから読み込むのもどうもスッキリしなくて直ぐに思い浮かぶのはやはりCacheですよね。Python3ならfunctoolsにlru_cacheがあるので、それを使えば楽々じゃないですか。
import functools
@functools.lru_cache(maxsize=30000)
def read_pic(path_to_pic):
pic_data = cv2.imread(path_to_pic)
return pic_data
import functools と @functools.lru_cache の二行を追加することでCacheを使うことになりました。
で、なんか変ですね、print(read_pic.cache_info()) の一行で出力されたCacheのHit率があまりにも低いのではないか、maxsizeを拡大したら少しは改善しますが、本来もっと活躍してもらいたかったのに失望ですよ。
もう少し考えたら、なんだバカなのはこの自分じゃないですか。本来一Epochごとに全ての訓練データを使うのに、LRUだとすでに使ったデータをCacheに入れるのは逆効果に他にないよ。まぁ、どうせならCacheも自作してしまえば。
import functools
from pathlib import Path
import cv2
import numpy as np
import random
from keras.utils import Sequence
import math
from collections import namedtuple
from _thread import RLock
import logging
logger = logging.getLogger(__name__)
_CacheInfo = namedtuple("CacheInfo", ["hits", "misses", "maxsize", "currsize"])
def _make_key(args):
return hash(args[0])
def simple_cache(maxsize=128):
if maxsize is not None and not isinstance(maxsize, int):
raise TypeError('Expected maxsize to be an integer or None')
def decorating_function(user_function):
wrapper = _simple_cache_wrapper(user_function, maxsize, _CacheInfo)
return functools.update_wrapper(wrapper, user_function)
return decorating_function
def _simple_cache_wrapper(user_function, maxsize, _CacheInfo):
sentinel = object() # unique object used to signal cache misses
make_key = _make_key # build a key from the function arguments
cache = {}
hits = misses = 0
full = False
cache_get = cache.get # bound method to lookup a key or return None
cache_len = cache.__len__ # get cache size without calling len()
lock = RLock() # because linkedlist updates aren't threadsafe
if maxsize == 0 or maxsize is None:
def wrapper(*args, **kwds):
# No caching -- just a statistics update after a successful call
nonlocal misses
result = user_function(*args, **kwds)
misses += 1
return result
else:
def wrapper(*args, **kwds):
# Size limited caching that tracks accesses by recency
nonlocal hits, misses, full
key = make_key(args)
with lock: # Python's built-in structures are thread-safe for single operations but
# a lock here will add almost no overhead, and will give me peace of mind.
result = cache_get(key)
if result is not None:
hits += 1
return result
result = user_function(*args, **kwds)
with lock:
if key in cache:
# Getting here means that this same key was added to the
# cache while the lock was released. Since the link
# update is already done, we need only return the
# computed result and update the count of misses.
pass
elif full:
# our simple is full,so just ignore the coming one
pass
else:
cache[key] = result
full = (cache_len() >= maxsize)
misses += 1
return result
def cache_info():
"""Report cache statistics"""
with lock:
return _CacheInfo(hits, misses, maxsize, cache_len())
def cache_clear():
"""Clear the cache and cache statistics"""
nonlocal hits, misses, full
with lock:
cache.clear()
hits = misses = 0
full = False
wrapper.cache_info = cache_info
wrapper.cache_clear = cache_clear
return wrapper
@simple_cache(maxsize=2000)
def read_pic(path_to_pic):
pic_data = cv2.imread(path_to_pic)
return pic_data
class FishImageGenerator(Sequence):
def __init__(self,file_paths=None,batch_size = 16,img_rows=256, img_cols=256, channels=3):
self._file_paths = file_paths
self._batch_size = batch_size
self._img_rows = img_rows
self._img_cols = img_cols
self._channels = channels
if file_paths is None:
self._file_paths = self._get_paths()
self._array_len = len(self._file_paths)
self._num_batches_per_epoch = math.ceil(self._array_len / batch_size)
self.reset()
def reset(self):
random.shuffle(self._file_paths)
self.inner_targets = []
for path_in_str in self._file_paths:
path_obj = Path(path_in_str)
self.inner_targets.append(path_obj.parent._parts[-1])
def get_targets(self):
return self.inner_targets
def _get_paths(self,root_dir=None):
if root_dir is None:
root_dir = Path.cwd()
pathlist = Path(root_dir).glob('**/*.jpg')
return pathlist
def __getitem__(self, idx):
start_pos = self._batch_size * idx
end_pos = start_pos + self._batch_size
if end_pos > self._array_len:
end_pos = self._array_len
item_paths = self._file_paths[start_pos: end_pos]
loaded_data = []
target_label = []
for path_in_str in item_paths:
path_obj = Path(path_in_str)
try:
img = read_pic(path_in_str)
#img = cv2.imread(path_in_str)
loaded_data.append(img)
target_label.append(path_obj.parent._parts[-1])
except Exception as e:
print(e)
logger.error(e)
imgs = np.array(loaded_data).astype(np.float32)
imgs = imgs /255.
imgs = imgs.reshape(-1, self._img_rows, self._img_cols, self._channels)
targets = np.array(target_label).astype("uint8")
targets = targets.reshape(-1, 1)
#self.inner_imgs,self.inner_targets = img,targets
return imgs, targets
def __len__(self):
"""Batch length"""
return self._num_batches_per_epoch
def on_epoch_end(self):
"""Task when end of epoch"""
print(read_pic.cache_info())
logger.info(read_pic.cache_info())
self.reset()
やっとこれで少しスッキリしました。でもメモリの消耗が画像サイズの十倍もあるようで、それではCacheの効果もだいぶ落ちますね。よく考えたら、本来圧縮された画像ファイルをNumpyのndarray形式に展開された時点で膨らむよね、それじゃあ当然Cacheできる数も減りますね。
pic_data = cv2.imread(path_to_pic)
with open(path_to_pic, 'rb') as f:
binary = f.read()
import sys
print ("array size is {}".format(sys.getsizeof(pic_data)))
print ("binary size is {}".format(sys.getsizeof(binary)))
測ってみたら、やはり20倍ぐらいの差がありますね。array size is 196736
binary size is 7833
それで、前のソースの中の
def read_pic(path_to_pic):
pic_data = cv2.imread(path_to_pic)
return pic_data
を下記のように直せば完成
@simple_cache(maxsize=150000)
def read_pic(path_to_pic):
#pic_data = cv2.imread(path_to_pic)
with open(path_to_pic,'rb') as f:
binary = f.read()
return binary
def bin2arr(binary):
arr = np.asarray(bytearray(binary), dtype=np.uint8)
pic_data = cv2.imdecode(arr, -1)
return pic_data
色々失敗して、やっと形になりました。でもそれで毎回画像ファイルの解凍にもCPU時間を費やさなければならないことになります、その辺の調整は訓練の状況をみてするしかないようです。
あ、言い忘れたところです、自作のCacheのソースの大部分はPythonのfunctoolsからパクったものです。