LoginSignup
7
7

More than 3 years have passed since last update.

TRN-pytorch をローカルで動かしてみた

Posted at

はじめに

人の動きを動画から判別したかったので、試しにTRN-pytorchをローカルで動かしたかったが、躓いてしまったので、参考になればと。
実際Google ColabとかGPUがある環境で動かせばすぐに動いたんじゃないかなとは思うが。。。

準備

gitからTRN-pytorchをダウンロードして、必要なモジュールをインストール

git clone --recursive https://github.com/metalbubble/TRN-pytorch
cd TRN-pytorch/pretrain
./download_models.sh
cd ../sample_data
./download_sample_data.sh

実行

いざ実行

./test_video.sh

動かない...。

RuntimeError: Error(s) in loading state_dict for InceptionV3:
        size mismatch for conv_batchnorm.weight: copying a param with shape torch.Size([1, 32]) from checkpoint, the shape in current model is torch.Size([32]).
        size mismatch for conv_batchnorm.bias: copying a param with shape torch.Size([1, 32]) from checkpoint, the shape in current model is torch.Size([32]).
        size mismatch for conv_batchnorm.running_mean: copying a param with shape torch.Size([1, 32]) from checkpoint, the shape in current model is torch.Size([32]).
        ...

長々とsize mismatchと出ます。調べたら以下に載っていました(https://github.com/metalbubble/TRN-pytorch/issues/10) 。
まず、ここからダウンロードします。

ダウンロードしたものをTRN-pytorch/model_zoo/bninception/に配置し、TRN-pytorch/model_zoo/bninception/pytorch_load.py の35行目を修正します。

pytorch_load.py
# self.load_state_dict(torch.utils.model_zoo.load_url(weight_url))
new_state_dict = {}
temp2 = torch.load('model_zoo/bninception/bn_inception-52deb4733.pth')
for k, v in temp2.items():
    if(k.split(".")[0] == 'last_linear'):
        new_state_dict['fc.' + k.split(".")[1]] = v
    else:
        new_state_dict[k] = v
self.load_state_dict(new_state_dict, strict=False)

修正し、実行

RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location='cpu' to map your storages to the CPU.

エラーは変わった。ローカルなのでcudaは確かにいらない。エラー通りに修正。
今回はTRN-pytorch配下のtest_video.py の107行目を修正。

test_video.py
checkpoint = torch.load(args.weights, map_location='cpu')

そろそろ...

AssertionError: Torch not compiled with CUDA enabled

はい、cudaね。test_video.pyの110行目と134行目に.cuda()があるので、.cuda()だけを削除し、再び実行。

Multi-Scale Temporal Relation Network Module in use ['8-frame relation', '7-frame relation', '6-frame relation', '5-frame relation', '4-frame relation', '3-frame relation', '2-frame relation']
Freezing BatchNorm2D except the first one.
Loading frames in sample_data/juggling_frames
RESULT ON sample_data/juggling_frames
1.000 -> juggling
0.000 -> catching
0.000 -> balancing
0.000 -> spinning
0.000 -> performing

結果が出力されましたね。
Google Colabとかでやれば、はじめの修正だけでいけると思います。後日やってみたいと思います。

7
7
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
7
7