LoginSignup
47
52

More than 5 years have passed since last update.

SSD(keras version2以上)の実装でエラーにはまったらここを直せ

Last updated at Posted at 2018-08-29

実装環境

Keras v2.2.0
Tensorflow v1.8.0
OpenCV v3.4.1

前提

画像データセット、アノテーションのpickleファイル、重みのhdf5ファイルはすべて準備万全という前提。

SSD(Single Shot MultiBox Detector) with Keras and Tensorflowについて

まず、SSDとは畳み込みニューラルネットワーク(CNN:Convolutional Neural Network)を用いた物体検出アルゴリズムで、300×300の入力に対して,VOC2007 testデータセットで74.3%のmAP1を達成している。
論文はこちらからhttps://arxiv.org/abs/1512.02325
といろいろ説明する前に、おそらく、このページに来てくれた読者諸君は、kerasのSSDの実装がうまくいなかい、エラーがでてしまうがどうやって解決すればいいのか困り果てている、だから解決法を簡単に教えてくれるサイトはないかなと調べて辿りついたかもしれない。
多くのサイトでは、keras(&tensorflow)フレームワークのSSDコードは、https://github.com/rykov8/ssd_keras が紹介されている。
ただし、Keras v1.2.2, Tensorflow v1.0.0, OpenCV v3.1.0-devで動作可能と謳っている。ローカルで実装する場合、Keras、Tensorflowのバージョンのダウングレードを余儀なくされているのが現状であろう。

ここでは上記のコードを、
Keras v2.2.0,
Tensorflowv1.8.0,
OpenCVv3.4.1
で動作可能にするため、修正箇所を以下に示す。

ssd.pyの修正箇所

kerasのupdateにより、ライブラリの「merge」が「concatenate」に変更になった。よって以下のように書き換える。

ssd.py
#以下の13行目
from keras.layers import merge 
# を下記の通り修正
from keras.layers import concatenate

# 282,290,298行目、そして316行目
merge => concatenate
#のように変更します。

#258,265,272,287行目 
mode='concat'
# は全て削除。

# 最後に
concat_axis=1 => axis=1
# へ変更

ssd_layers.pyの変更箇所

ssd_layers.py
# 111行目の
 def get_output_shape_for(self, input_shape):
# を以下に変更
 def compute_output_shape(self, input_shape):

SSD_training.ipynbの変更箇所

SSD_training.ipynb内でモデルの学習を実行するコードも、同様に変更。kerasバージョンアップでfit_generatorの引数変更のため。
https://keras.io/ja/models/sequential/

(誤)SSD_training.ipynb

nb_epoch = 30
history = model.fit_generator(gen.generate(True), gen.train_batches,
                              nb_epoch, verbose=1,
                              callbacks=callbacks,
                              validation_data=gen.generate(False),
                              nb_val_samples=gen.val_batches,
                              nb_worker=1)
(正)SSD_training.ipynb
epochs = 30
batch_size=16 #例
history = model.fit_generator(gen.generate(True), 
                     steps_per_epoch =gen.train_batches//batch_size,
                              epochs=epochs, verbose=1,
                              callbacks=callbacks,
                              validation_data=gen.generate(False),
                              validation_steps=gen.val_batches,
                              workers=1)

workersはデフォルトで1に設定されている。nb_workerのままでも大丈夫(だと思う)。

videotest.pyの変更部分

自ら準備した動画でDetectionを行う時に使う、testing_utilsフォルダ内のコード。
しかし、vedeotest.pyは、もともとOpenCVのV2用にコードが書かれている。V3対応にするため、以下のように一部変更する。

videotest.py
#87,88行目を下記に変更
     vidw = vid.get(cv2.CAP_PROP_FRAME_WIDTH)
     vidh = vid.get(cv2.CAP_PROP_FRAME_HEIGHT)

# 93行目を下記に変更
     vid.set(cv2.CAP_PROP_POS_MSEC, start_frame)

Good Luck!

47
52
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
47
52