Running RAFT, a neural network model for optical flow, on Google Colaboratory

Posted at 2022-06-02

I heard that Google Colaboratory gives you up to 12 GB of GPU memory, figured RAFT should run reasonably well on that, and decided to try it.
I suspect Full HD is about the limit, but I haven't bothered to test it.
No idea who actually needs this.

What is RAFT?

I'm too sleepy to write this up, so please check the introduction of the paper.

Environment setup

On Google Colaboratory, the preinstalled PyTorch and CUDA versions are apparently fixed.
RAFT calls for the following:

pytorch=1.6.0 
torchvision=0.7.0 
cudatoolkit=10.1

There are also some quirks specific to Google Colaboratory, so we'll work around those as we go.
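Before setting anything up, it doesn't hurt to check which GPU (and how much memory) the session actually got, since this varies from session to session:

!nvidia-smi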

Steps

I set things up as described on this site.

The versions below are as of 2022/6/3.
Let's start with CUDA.

!nvcc --version

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2020 NVIDIA Corporation
Built on Mon_Oct_12_20:09:46_PDT_2020
Cuda compilation tools, release 11.1, V11.1.105
Build cuda_11.1.TC455_06.29190527_0

!python -c "import torch; print(torch.__version__)"

1.11.0+cu113

!ls -d /usr/local/cuda-*

/usr/local/cuda-10.0 /usr/local/cuda-11 /usr/local/cuda-11.1
/usr/local/cuda-10.1 /usr/local/cuda-11.0

import os
p = os.getenv('PATH')
ld = os.getenv('LD_LIBRARY_PATH')
# Prepend CUDA 10.1 so that nvcc and the runtime libraries resolve to 10.1 instead of the default 11.1
os.environ['PATH'] = f"/usr/local/cuda-10.1/bin:{p}"
os.environ['LD_LIBRARY_PATH'] = f"/usr/local/cuda-10.1/lib64:{ld}"
!nvcc --version

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2019 NVIDIA Corporation
Built on Sun_Jul_28_19:07:16_PDT_2019
Cuda compilation tools, release 10.1, V10.1.243

CUDA is now version 10.1.
Next, let's bring PyTorch down to 1.6.0.

!pip install torch==1.6.0 torchvision==0.7.0 tensorboard

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torch==1.6.0
Downloading torch-1.6.0-cp37-cp37m-manylinux1_x86_64.whl (748.8 MB)
|████████████████████████████████| 748.8 MB 19 kB/s
Collecting torchvision==0.7.0
Downloading torchvision-0.7.0-cp37-cp37m-manylinux1_x86_64.whl (5.9 MB)
|████████████████████████████████| 5.9 MB 31.6 MB/s
Requirement already satisfied: tensorboard in /usr/local/lib/python3.7/dist-packages (2.8.0)
Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from torch==1.6.0) (1.21.6)
Requirement already satisfied: future in /usr/local/lib/python3.7/dist-packages (from torch==1.6.0) (0.16.0)
Requirement already satisfied: pillow>=4.1.1 in /usr/local/lib/python3.7/dist-packages (from torchvision==0.7.0) (7.1.2)
Requirement already satisfied: wheel>=0.26 in /usr/local/lib/python3.7/dist-packages (from tensorboard) (0.37.1)
Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /usr/local/lib/python3.7/dist-packages (from tensorboard) (0.4.6)
Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.7/dist-packages (from tensorboard) (3.3.7)
Requirement already satisfied: google-auth<3,>=1.6.3 in /usr/local/lib/python3.7/dist-packages (from tensorboard) (1.35.0)
Requirement already satisfied: setuptools>=41.0.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard) (57.4.0)
Requirement already satisfied: requests<3,>=2.21.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard) (2.23.0)
Requirement already satisfied: tensorboard-data-server<0.7.0,>=0.6.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard) (0.6.1)
Requirement already satisfied: grpcio>=1.24.3 in /usr/local/lib/python3.7/dist-packages (from tensorboard) (1.46.3)
Requirement already satisfied: protobuf>=3.6.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard) (3.17.3)
Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard) (1.8.1)
Requirement already satisfied: absl-py>=0.4 in /usr/local/lib/python3.7/dist-packages (from tensorboard) (1.0.0)
Requirement already satisfied: werkzeug>=0.11.15 in /usr/local/lib/python3.7/dist-packages (from tensorboard) (1.0.1)
Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from absl-py>=0.4->tensorboard) (1.15.0)
Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.7/dist-packages (from google-auth<3,>=1.6.3->tensorboard) (0.2.8)
Requirement already satisfied: cachetools<5.0,>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from google-auth<3,>=1.6.3->tensorboard) (4.2.4)
Requirement already satisfied: rsa<5,>=3.1.4 in /usr/local/lib/python3.7/dist-packages (from google-auth<3,>=1.6.3->tensorboard) (4.8)
Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.7/dist-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard) (1.3.1)
Requirement already satisfied: importlib-metadata>=4.4 in /usr/local/lib/python3.7/dist-packages (from markdown>=2.6.8->tensorboard) (4.11.4)
Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata>=4.4->markdown>=2.6.8->tensorboard) (3.8.0)
Requirement already satisfied: typing-extensions>=3.6.4 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata>=4.4->markdown>=2.6.8->tensorboard) (4.2.0)
Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in /usr/local/lib/python3.7/dist-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tensorboard) (0.4.8)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.21.0->tensorboard) (2022.5.18.1)
Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.21.0->tensorboard) (2.10)
Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.21.0->tensorboard) (1.24.3)
Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.21.0->tensorboard) (3.0.4)
Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.7/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard) (3.2.0)
Installing collected packages: torch, torchvision
Attempting uninstall: torch
Found existing installation: torch 1.11.0+cu113
Uninstalling torch-1.11.0+cu113:
Successfully uninstalled torch-1.11.0+cu113
Attempting uninstall: torchvision
Found existing installation: torchvision 0.12.0+cu113
Uninstalling torchvision-0.12.0+cu113:
Successfully uninstalled torchvision-0.12.0+cu113
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
torchtext 0.12.0 requires torch==1.11.0, but you have torch 1.6.0 which is incompatible.
torchaudio 0.11.0+cu113 requires torch==1.11.0, but you have torch 1.6.0 which is incompatible.
Successfully installed torch-1.6.0 torchvision-0.7.0

The errors say that the installed torchtext and torchaudio versions are incompatible with PyTorch 1.6.0.
They don't matter here, so I left them alone, but if they bother you, you can install matching versions together with the packages above, as in the sketch below.
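A hedged variant of the install command that also pins what I believe are the matching torchtext and torchaudio releases for PyTorch 1.6.0 (this step is optional for RAFT itself):

!pip install torch==1.6.0 torchvision==0.7.0 torchtext==0.7.0 torchaudio==0.6.0 tensorboard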

!python -c "import torch; print(torch.__version__)"

1.6.0

We've confirmed that PyTorch is now 1.6.0.
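While we're at it, a quick sanity check that this PyTorch build can see the GPU; note that the reported CUDA version is whatever the wheel was built against, so it need not read exactly 10.1:

!python -c "import torch; print(torch.version.cuda, torch.cuda.is_available())"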

Next, we want to grab RAFT from GitHub, which can also be done from within Google Colaboratory.

This part also follows the instructions on this site.

!git clone https://github.com/princeton-vl/RAFT.git

Cloning into 'RAFT'...
remote: Enumerating objects: 144, done.
remote: Total 144 (delta 0), reused 0 (delta 0), pack-reused 144
Receiving objects: 100% (144/144), 10.01 MiB | 30.25 MiB/s, done.
Resolving deltas: 100% (57/57), done.

import os
path = '/content/RAFT'

# Change the working directory to the cloned RAFT repository
os.chdir(path)

That takes care of downloading RAFT and switching the working directory.
Let's go ahead and run the demo script.

Checking that it works

https://github.com/princeton-vl/RAFT#demos
We follow the Demos section linked above.
First, download the pretrained models.

!./download_models.sh

--2022-06-02 15:02:58-- https://www.dropbox.com/s/4j4z58wuv8o0mfz/models.zip
Resolving www.dropbox.com (www.dropbox.com)... 162.125.3.18, 2620:100:6018:18::a27d:312
Connecting to www.dropbox.com (www.dropbox.com)|162.125.3.18|:443... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: /s/raw/4j4z58wuv8o0mfz/models.zip [following]
--2022-06-02 15:03:01-- https://www.dropbox.com/s/raw/4j4z58wuv8o0mfz/models.zip
Reusing existing connection to www.dropbox.com:443.
HTTP request sent, awaiting response... 302 Found
Location: https://uc92a4d40d4bf58266ef7332dc44.dl.dropboxusercontent.com/cd/0/inline/BmfcKBTZTPdWhRCwkEzla2W0-UdYJn4h10uFbZfesQsMoYoR-5xoykAPQePJxSAIHZBFVeSq69X45cGbL_5MBrAon6QhQSGgY-leHe6my7gCyR_0LYxKbECzfGpuQmyg3_jBdXbbvwFKulh4lDV_7xVIlkhShcis29YT2VPN90vbOA/file# [following]
--2022-06-02 15:03:02-- https://uc92a4d40d4bf58266ef7332dc44.dl.dropboxusercontent.com/cd/0/inline/BmfcKBTZTPdWhRCwkEzla2W0-UdYJn4h10uFbZfesQsMoYoR-5xoykAPQePJxSAIHZBFVeSq69X45cGbL_5MBrAon6QhQSGgY-leHe6my7gCyR_0LYxKbECzfGpuQmyg3_jBdXbbvwFKulh4lDV_7xVIlkhShcis29YT2VPN90vbOA/file
Resolving uc92a4d40d4bf58266ef7332dc44.dl.dropboxusercontent.com (uc92a4d40d4bf58266ef7332dc44.dl.dropboxusercontent.com)... 162.125.3.15, 2620:100:601b:15::a27d:80f
Connecting to uc92a4d40d4bf58266ef7332dc44.dl.dropboxusercontent.com (uc92a4d40d4bf58266ef7332dc44.dl.dropboxusercontent.com)|162.125.3.15|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: /cd/0/inline2/Bmd4Ab6Xs4plNQxRNLmhz0V5eGxVekEXI_C6tGd7YCdgpp77GtDYiC_Dyzem4mkXUD1eRJSeV5hVKe_MSyXiG1kbwEMKOmZzEuQyI9W_j11j14mND28b-7wUIStGT0YKvpHe3EyCtnnnPN4m2xcKwxsEMjFitbN-tLvNgNoefPGemUgekO5rRtEmhh3po_lc1upUKFa9PwFL7t8jpR4kap5UAq8lLOML0Vmn6MX-taLQ-wprFEVu3ttFsBg3Bl6YVTe_nzM_hrBiL0AsosDsIdos4ll6qoFq6sUKxYCtAoOj2MfLt42ZNbeHodVoUnJ_dARPwO6fFKdPg-uCplP9aTkc13X14QaB2RmY1Y-1snAOzo-z2_3xekFX8vq2ymjpaDIU8op2vfsPe6TkiDvt8VMpHLyzRgtuHyHYMCLDJFLZyA/file [following]
--2022-06-02 15:03:02-- https://uc92a4d40d4bf58266ef7332dc44.dl.dropboxusercontent.com/cd/0/inline2/Bmd4Ab6Xs4plNQxRNLmhz0V5eGxVekEXI_C6tGd7YCdgpp77GtDYiC_Dyzem4mkXUD1eRJSeV5hVKe_MSyXiG1kbwEMKOmZzEuQyI9W_j11j14mND28b-7wUIStGT0YKvpHe3EyCtnnnPN4m2xcKwxsEMjFitbN-tLvNgNoefPGemUgekO5rRtEmhh3po_lc1upUKFa9PwFL7t8jpR4kap5UAq8lLOML0Vmn6MX-taLQ-wprFEVu3ttFsBg3Bl6YVTe_nzM_hrBiL0AsosDsIdos4ll6qoFq6sUKxYCtAoOj2MfLt42ZNbeHodVoUnJ_dARPwO6fFKdPg-uCplP9aTkc13X14QaB2RmY1Y-1snAOzo-z2_3xekFX8vq2ymjpaDIU8op2vfsPe6TkiDvt8VMpHLyzRgtuHyHYMCLDJFLZyA/file
Reusing existing connection to uc92a4d40d4bf58266ef7332dc44.dl.dropboxusercontent.com:443.
HTTP request sent, awaiting response... 200 OK
Length: 81977417 (78M) [application/zip]
Saving to: ‘models.zip’

models.zip 100%[===================>] 78.18M 92.6MB/s in 0.8s

2022-06-02 15:03:03 (92.6 MB/s) - ‘models.zip’ saved [81977417/81977417]

Archive: models.zip
creating: models/
inflating: models/raft-kitti.pth
inflating: models/raft-sintel.pth
inflating: models/raft-chairs.pth
inflating: models/raft-things.pth
inflating: models/raft-small.pth

The models have been downloaded. Each of these checkpoints was trained on one of the standard optical flow datasets.

You can check the details at the link below.
https://github.com/princeton-vl/RAFT#required-data
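If you want to double-check which checkpoints ended up on disk, listing the models directory is enough:

!ls models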

With that, we could run the demo, but cv2.imshow and friends don't work properly on Google Colaboratory.
So let's rewrite part of demo.py to save the results to disk instead.

demo.py
import sys
sys.path.append('core')

import argparse
import os
import cv2
import glob
import numpy as np
import torch
from PIL import Image

from raft import RAFT
from utils import flow_viz
from utils.utils import InputPadder

DEVICE = 'cuda'

def load_image(imfile):
    img = np.array(Image.open(imfile)).astype(np.uint8)
    img = torch.from_numpy(img).permute(2, 0, 1).float()
    return img[None].to(DEVICE)

# Added an index argument (i) so the output files can be numbered
def viz(img, flo, i):
    img = img[0].permute(1,2,0).cpu().numpy()
    flo = flo[0].permute(1,2,0).cpu().numpy()
    
    # map flow to rgb image
    flo = flow_viz.flow_to_image(flo)
    img_flo = np.concatenate([img, flo], axis=0)

    # import matplotlib.pyplot as plt
    # plt.imshow(img_flo / 255.0)
    # plt.show()

    # cv2.imshow('image', img_flo[:, :, [2,1,0]]/255.0)
    # cv2.waitKey()
    # Removed the /255.0: cv2.imwrite expects 0-255 pixel values, so dividing left an all-black image.
    cv2.imwrite(f'result/image_{i+1}.png', img_flo[:, :, [2,1,0]])


def demo(args):
    model = torch.nn.DataParallel(RAFT(args))
    model.load_state_dict(torch.load(args.model))

    model = model.module
    model.to(DEVICE)
    model.eval()

    with torch.no_grad():
        images = glob.glob(os.path.join(args.path, '*.png')) + \
                 glob.glob(os.path.join(args.path, '*.jpg'))
        # Sort the frames and use enumerate so the outputs are saved in frame order.
        images = sorted(images)
        for i, (imfile1, imfile2) in enumerate(zip(images[:-1], images[1:])):
            image1 = load_image(imfile1)
            image2 = load_image(imfile2)

            padder = InputPadder(image1.shape)
            image1, image2 = padder.pad(image1, image2)

            flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
            # Pass the frame index to viz
            viz(image1, flow_up, i)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', help="restore checkpoint")
    parser.add_argument('--path', help="dataset for evaluation")
    parser.add_argument('--small', action='store_true', help='use small model')
    parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
    parser.add_argument('--alternate_corr', action='store_true', help='use efficent correlation implementation')
    args = parser.parse_args()
    # Create a result directory and dump the outputs there.
    os.makedirs("result", exist_ok=True)
    demo(args)
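As an aside, if you would rather preview a frame inline instead of (or in addition to) writing it to disk, Colab ships a drop-in replacement for cv2.imshow. A minimal sketch of what the saving line in viz could become; note that this only renders when the code runs inside a notebook cell, not when demo.py is launched via !python:

# Inline display in Colab as an alternative to cv2.imwrite.
# cv2_imshow expects an 8-bit BGR image, the same array passed to cv2.imwrite above.
from google.colab.patches import cv2_imshow
cv2_imshow(img_flo[:, :, [2, 1, 0]].astype(np.uint8))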

After making these changes, upload the file back under the RAFT directory, replacing the original demo.py.
Then run it with:

!python demo.py --model=models/raft-sintel.pth --path=demo-frames
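If GPU memory turns out to be tight (for example with larger frames), the repository also ships a lighter checkpoint. A variant of the same command, assuming raft-small.pth is meant to be paired with the --small flag as in the RAFT README:

!python demo.py --model=models/raft-small.pth --small --path=demo-frames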

For downloading the results, I used this as a reference.

# current directory: /content/RAFT
# Zip the folder you want to download
!zip -r result.zip result

# Download the zipped file to your local machine
from google.colab import files
files.download("result.zip")

adding: result/ (stored 0%)
adding: result/image_4.png (deflated 14%)
adding: result/image_8.png (deflated 15%)
adding: result/image_9.png (deflated 15%)
adding: result/image_3.png (deflated 13%)
adding: result/image_1.png (deflated 13%)
adding: result/image_6.png (deflated 13%)
adding: result/image_7.png (deflated 13%)
adding: result/image_5.png (deflated 13%)
adding: result/image_2.png (deflated 13%)

Unzip it and you should find images like the one below inside result.
image_1.png
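If you would rather skip the zip-and-download round trip, another option is to copy the result folder straight to Google Drive. A minimal sketch, assuming you are fine with mounting Drive; the destination path is just an example:

from google.colab import drive
import shutil

# Mount Google Drive into the Colab filesystem (this prompts for authorization).
drive.mount('/content/drive')

# Copy the result directory to Drive; this fails if the destination already exists.
shutil.copytree('/content/RAFT/result', '/content/drive/MyDrive/RAFT_result')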

Wrapping up

I was going to write about a docker-compose setup as well, but then I remembered: I don't have a graphics card, which is exactly why I'm doing this on Colab in the first place.

For now, I'll post the docker-compose configuration I put together earlier as its own article.
→ Written.

If I find the time, I'll also write about fixing up the training scripts and so on.
