概要
ネットに落ちてたSSDの実装ではまってしまい、それの解決方法を備忘録として残します。
私と同じようにSSDではまってしまった方の力になれれば幸いです。
動かしたいコード
githubに公開されていたSSDのリポジトリ
https://github.com/YutaroOgawa/pytorch_advanced/tree/master/2_objectdetection
これを.ipynbから.pyの形式に変更して学習および物体検知の実行したい
特にこの中の2_7_ssd_training.ipynb(.py)がエラーとなり実行できなかった
ざっくり動作環境
- python 3.12
- pytorch 2.3
- numpy 1.26
エラー箇所
エラーメッセージは記録していなかったが、2_7_training.pyの実行時にデータ拡張プログラム(utils内のdata_augumentation.py)が呼び出され、その中のnumpy.random.choice()の部分でエラーとなっていた。
考えられる原因の1つは
numpy.random.choice()の引数でndarrayやint配列以外のオブジェクトが渡されたからだと思われる。おそらく関数が呼び出される以前のどこかの関数が仕様変更され結果的にndarrayやint配列以外のオブジェクトを返していた?
もしくは単純にnumpy.random.choiceの仕様がバージョン変更によって変わっていた?でも公式ドキュメント見てもそんなこと書いてないように見える…
コード修正
utile/data_augumentation.py内246行目付近
こいつが原因なのでコメントアウトもしくは削除
mode = random.choice(self.sample_options) # 削除
以下のコードを追加
from random import choice
mode = choice(self.sample_options)
numpyのrandom.choiceではなく、python標準搭載のrandomモジュールのchoice関数を使用することでエラーを回避できた。
np.random.choiceと標準搭載random.choiceの違い
両関数は引数として与えられたデータからランダムな要素を選択して返すという点では一緒である。
しかし以下のような相違点がある
-
numpy.random.choice()
- 引数としてndarray(numpy配列)もしくはint型配列を受け付ける
- ndarrayに最適化されているため処理は早い
-
random.choice()
- 引数としてリストやタプル、文字列などのシーケンスに加えてndarrayもサポートされておりインデックスで要素を指定してくれるオブジェクトを受け付ける
- つまり自由度が高い関数といえる
おまけ
今回私はローカルのanaconda下で.pyの形式で実行した。.pyとjupyter形式の.ipynbも若干のコード変更が必要なのでそれも記載する。jupyter形式での実行を前提としている方はここでは対象外である。
まずjupyter形式(google colab含む)だとmatplotlibで作図する際、基本的にはplt.show()は不要である。しかし.py形式で一括実行する際はplt.show()は必須である。
そのためリポジトリ内の物体検知実行プログラム2_8_ssd_inference.pyを実行してもエラーは出ないが検出画像も表示されない。
なのでutils/ssd_predict_show.py内、class SSDPredictShowの vis_bbox()で一番最後(163行目あたり)に、plt.show()を加える。
class SSDPredictShow():
・・・
def vis_bbox(self, rgb_img, bbox, label_index, scores, label_names):
・・・
plt.show() # <-追加
参考