はじめに
人の動きを動画から判別したかったので、試しにTRN-pytorchをローカルで動かしたかったが、躓いてしまったので、参考になればと。
実際Google ColabとかGPUがある環境で動かせばすぐに動いたんじゃないかなとは思うが。。。
準備
gitからTRN-pytorchをダウンロードして、必要なモジュールをインストール
git clone --recursive https://github.com/metalbubble/TRN-pytorch
cd TRN-pytorch/pretrain
./download_models.sh
cd ../sample_data
./download_sample_data.sh
実行
いざ実行
./test_video.sh
動かない...。
RuntimeError: Error(s) in loading state_dict for InceptionV3:
size mismatch for conv_batchnorm.weight: copying a param with shape torch.Size([1, 32]) from checkpoint, the shape in current model is torch.Size([32]).
size mismatch for conv_batchnorm.bias: copying a param with shape torch.Size([1, 32]) from checkpoint, the shape in current model is torch.Size([32]).
size mismatch for conv_batchnorm.running_mean: copying a param with shape torch.Size([1, 32]) from checkpoint, the shape in current model is torch.Size([32]).
...
長々とsize mismatchと出ます。調べたら以下に載っていました(https://github.com/metalbubble/TRN-pytorch/issues/10) 。
まず、[ここ][link-1]からダウンロードします。
[link-1]:http://data.lip6.fr/cadene/pretrainedmodels/bn_inception-52deb4733.pt
ダウンロードしたものをTRN-pytorch/model_zoo/bninception/に配置し、TRN-pytorch/model_zoo/bninception/pytorch_load.py の35行目を修正します。
# self.load_state_dict(torch.utils.model_zoo.load_url(weight_url))
new_state_dict = {}
temp2 = torch.load('model_zoo/bninception/bn_inception-52deb4733.pth')
for k, v in temp2.items():
if(k.split(".")[0] == 'last_linear'):
new_state_dict['fc.' + k.split(".")[1]] = v
else:
new_state_dict[k] = v
self.load_state_dict(new_state_dict, strict=False)
修正し、実行
RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location='cpu' to map your storages to the CPU.
エラーは変わった。ローカルなのでcudaは確かにいらない。エラー通りに修正。
今回はTRN-pytorch配下のtest_video.py の107行目を修正。
checkpoint = torch.load(args.weights, map_location='cpu')
そろそろ...
AssertionError: Torch not compiled with CUDA enabled
はい、cudaね。test_video.pyの110行目と134行目に.cuda()があるので、.cuda()だけを削除し、再び実行。
Multi-Scale Temporal Relation Network Module in use ['8-frame relation', '7-frame relation', '6-frame relation', '5-frame relation', '4-frame relation', '3-frame relation', '2-frame relation']
Freezing BatchNorm2D except the first one.
Loading frames in sample_data/juggling_frames
RESULT ON sample_data/juggling_frames
1.000 -> juggling
0.000 -> catching
0.000 -> balancing
0.000 -> spinning
0.000 -> performing
結果が出力されましたね。
Google Colabとかでやれば、はじめの修正だけでいけると思います。後日やってみたいと思います。