7
5

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

Masked Autoencodersを比較的新しい環境で動かしてみる

Last updated at Posted at 2023-08-27

概要

Masked Autoencodersというアルゴリズムに興味があり、動かそうとしたら自分の環境では色々と修正が必要なことが判明した。動作させるまでに行った内容をメモとして残す。

背景

公式実装はtimmやpytorchのバージョンが古く、自分のGPU(RTX A4500)では動作させられなかった。せっかくなので比較的最新のバージョンで動くように修正する。

修正後リポジトリ

動作環境

  • OS: Ubuntu22.04
  • GPU: RTX A4500
  • Docker Image: pytorch/pytorch:2.0.1-cuda11.7-cudnn8-devel
import torch,torchvision,timm

print(torch.__version__)
print(torchvision.__version__)
print(torch.cuda.is_available())
print(timm.__version__)

#2.0.1
#0.15.2
#True
#0.9.5
$ nvidia-smi
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 530.30.02              Driver Version: 530.30.02    CUDA Version: 12.1     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                  Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf            Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA RTX A4500                On | 00000000:01:00.0 Off |                  Off |
| 30%   40C    P8               28W / 200W|   1597MiB / 20470MiB |      1%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
+---------------------------------------------------------------------------------------+

$ nvcc -V
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2022 NVIDIA Corporation
Built on Wed_Jun__8_16:49:14_PDT_2022
Cuda compilation tools, release 11.7, V11.7.99
Build cuda_11.7.r11.7/compiler.31442593_0

変更箇所

timmのバージョンチェックを無効化

main_pretrain.py
-assert timm.__version__ == "0.3.2"  # version check

torch._six.infの使用をやめて、torch.infに変更

util/misc.py.py
-from torch._six import inf
-if norm_type == inf:
+if norm_type == torch.inf:

引数qk_scaleを削除(複数ヶ所)

models_mae.py
-Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)
+Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer) for i in range(depth)

np.floatをfloatに変更

util/pos_embed.py
-omega = np.arange(embed_dim // 2, dtype=np.float)
+omega = np.arange(embed_dim // 2, dtype=float)

add_weight_decay()param_groups_weight_decay()に変更

main_pretrain.py
-param_groups = optim_factory.add_weight_decay(model_without_ddp,args.weight_decay)
+param_groups = optim_factory.param_groups_weight_decay(model_without_ddp, weight_decay=args.weight_decay)

実行には影響無いがwarning対策

main_pretrain.py
-from torch.utils.tensorboard import SummaryWriter
+from torch.utils.tensorboard.writer import SummaryWriter

※実行には影響無いがBlack Formatterで一部のファイルはフォーマット済み。

動作検証

概要

今回はCIFAR10データセットを利用して学習し、学習済みモデルを使ってマスク画像が復元できるかを確認する。

変更点

CIFAR10データセットをmaeディレクトリ以下に次のように配置

/workspaces/mae$ tree -L 3
.
|-- data
|   |-- CIFAR10
|   |   |-- train
|   |   `-- val

CIFAR10データセット用にtransformsを更新
※MAEはViTを使用しており入力画像は224x224を前提としているので合わせる。

main_pretrain.py
# simple augmentation
transform_train = transforms.Compose(
    [
        # 画像サイズを32x32から224x224に変更
        transforms.Resize(224),
        # リサイズ後の画像を224x224に変更
        transforms.RandomResizedCrop(224, scale=(0.2, 1.0),interpolation=transforms.InterpolationMode.BICUBIC),  # 3 is bicubic
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        # CIFAR10データセット用に平均と標準偏差
        transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2616]),
    ]
)

変更

main_pretrain.py
parser.add_argument("--input_size", default=32, type=int, help="images input size")
parser.add_argument("--data_path", default="./data/CIFAR10", type=str, help="dataset path")

実行

とりあえず100epochで学習(今回の環境では約6時間14分)

python main_pretrain.py --epochs 100

画像1(1epoch目)
output1.png
画像2(40epoch目)
output2.png
画像3(99epoch目)
output3.png

学習が進むごとにマスクされた部分の学習ができていることを確認。

7
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
7
5

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?