実装環境
Keras v2.2.0
Tensorflow v1.8.0
OpenCV v3.4.1
前提
画像データセット、アノテーションのpickleファイル、重みのhdf5ファイルはすべて準備万全という前提。
SSD(Single Shot MultiBox Detector) with Keras and Tensorflowについて
まず、SSDとは畳み込みニューラルネットワーク(CNN:Convolutional Neural Network)を用いた物体検出アルゴリズムで、300×300の入力に対して,VOC2007 testデータセットで74.3%のmAP1を達成している。
論文はこちらからhttps://arxiv.org/abs/1512.02325
といろいろ説明する前に、おそらく、このページに来てくれた読者諸君は、kerasのSSDの実装がうまくいなかい、エラーがでてしまうがどうやって解決すればいいのか困り果てている、だから解決法を簡単に教えてくれるサイトはないかなと調べて辿りついたかもしれない。
多くのサイトでは、keras(&tensorflow)フレームワークのSSDコードは、https://github.com/rykov8/ssd_keras が紹介されている。
ただし、Keras v1.2.2, Tensorflow v1.0.0, OpenCV v3.1.0-devで動作可能と謳っている。ローカルで実装する場合、Keras、Tensorflowのバージョンのダウングレードを余儀なくされているのが現状であろう。
ここでは上記のコードを、
Keras v2.2.0,
Tensorflowv1.8.0,
OpenCVv3.4.1
で動作可能にするため、修正箇所を以下に示す。
ssd.pyの修正箇所
kerasのupdateにより、ライブラリの「merge」が「concatenate」に変更になった。よって以下のように書き換える。
#以下の13行目
from keras.layers import merge
# を下記の通り修正
from keras.layers import concatenate
# 282,290,298行目、そして316行目
merge => concatenate
#のように変更します。
#258,265,272,287行目
mode='concat'
# は全て削除。
# 最後に
concat_axis=1 => axis=1
# へ変更
ssd_layers.pyの変更箇所
# 111行目の
def get_output_shape_for(self, input_shape):
# を以下に変更
def compute_output_shape(self, input_shape):
SSD_training.ipynbの変更箇所
SSD_training.ipynb内でモデルの学習を実行するコードも、同様に変更。kerasバージョンアップでfit_generatorの引数変更のため。
https://keras.io/ja/models/sequential/
nb_epoch = 30
history = model.fit_generator(gen.generate(True), gen.train_batches,
nb_epoch, verbose=1,
callbacks=callbacks,
validation_data=gen.generate(False),
nb_val_samples=gen.val_batches,
nb_worker=1)
epochs = 30
batch_size=16 #例
history = model.fit_generator(gen.generate(True),
steps_per_epoch =gen.train_batches//batch_size,
epochs=epochs, verbose=1,
callbacks=callbacks,
validation_data=gen.generate(False),
validation_steps=gen.val_batches,
workers=1)
workersはデフォルトで1に設定されている。nb_workerのままでも大丈夫(だと思う)。
videotest.pyの変更部分
自ら準備した動画でDetectionを行う時に使う、testing_utilsフォルダ内のコード。
しかし、vedeotest.pyは、もともとOpenCVのV2用にコードが書かれている。V3対応にするため、以下のように一部変更する。
#87,88行目を下記に変更
vidw = vid.get(cv2.CAP_PROP_FRAME_WIDTH)
vidh = vid.get(cv2.CAP_PROP_FRAME_HEIGHT)
# 93行目を下記に変更
vid.set(cv2.CAP_PROP_POS_MSEC, start_frame)
Good Luck!