LoginSignup
3

More than 1 year has passed since last update.

PyTorchでサクッとオリジナル画像分類モデルを作る【TIMM / Pytorch Image Models】

Last updated at Posted at 2021-11-29

はじめに

kaggleNishikaなどのデータ分析コンペティションでは画像分類を行うことがしばしばあります。
難しいことを考えずにそれなりの精度を出すための画像分類モデルをサクッと作る方法を紹介します。

環境

  • Ubuntu 20.04
  • NVIDIA GeForce RTX 3090
  • CUDA Version: 11.4
  • Docker version 20.10.7, build 20.10.7-0ubuntu1~20.04.2

Dockerイメージのダウンロード

Docker Hubからご自身のCUDAのバージョンに合わせてダウンロードしてください。

CUDAのバージョン確認方法

$ nvidia-smi
Mon Oct 18 11:10:24 2021
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.63.01    Driver Version: 470.63.01    CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA GeForce ...  Off  | 00000000:01:00.0 Off |                  N/A |
| 30%   35C    P8    24W / 350W |     20MiB / 24268MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

右上のCUDA Version: 11.4を確認してください。

PyTorchイメージのダウンロード

2021年10月時点ではDocker Hub上にあがっているPyTorchイメージのうち、対応しているCUDAの最新バージョンは11.1となっています。
筆者の環境ではローカルにインストールされているCUDAのバージョンとDockerイメージのCUDAのバージョンが異なっていても問題なく動作しました。

ダウンロードするイメージは末尾がdevelのものを選択してください。
pullコマンドをコピーし、ターミナルにペーストして実行します。

$ docker pull pytorch/pytorch:1.9.0-cuda11.1-cudnn8-devel

Dockerコンテナの起動

次のコマンドを実行するとコンテナの起動&ログインができます。
pytorch/pytorch:1.9.0-cuda11.1-cudnn8-develの部分は先ほどpullしてきたイメージに合わせてください。

$ docker run -it --gpus all pytorch/pytorch:1.9.0-cuda11.1-cudnn8-devel /bin/bash

TIMMのリポジトリをclone

pullしてきたDockerイメージにはgitが入っていないため、aptでgitを入れた後にcloneします。

$ apt update
$ apt install -y git
$ git clone https://github.com/rwightman/pytorch-image-models.git

cloneできたらそのディレクトリに移動します。

$ cd pytorch-image-models

TIMMについて

公式リポジトリでは、TIMMは次のように説明されています。

PyTorch Image Models (timm) is a collection of image models, layers, utilities, optimizers, schedulers, data-loaders / augmentations, and reference training / validation scripts that aim to pull together a wide variety of SOTA models with ability to reproduce ImageNet training results.

つまり、SOTA (State of the Art) を達成した画像分類モデルを再現するために必要なものが詰まったリポジトリといえるでしょう。

2021年10月時点では次のモデルが公開されています。(一部抜粋)

  • Xception
  • Vision Transformer
  • VGG
  • ResNet
  • MobileNet
  • Big Transfer ResNetV2 (BiT)
  • EfficientNet

学習から推論まで

画像を用意する

分類をしたい画像を用意します。例えば猫と犬の2種類の画像分類モデルを作成したい場合、ディレクトリ構成としては次のようにtrainの直下にcatディレクトリとdogディレクトリを作成し、その中に画像を放り込みます。

train
├── cat
│   ├── cat1.jpg
│   ├── cat2.jpg
│   ├── cat3.jpg
│   ├── cat4.jpg
│   ├── cat5.jpg
│   ├── cat6.jpg
│          ・
│          ・
│          ・
└── dog
    ├── dog1.jpg
    ├── dog2.jpg
    ├── dog3.jpg
    ├── dog4.jpg
    ├── dog5.jpg
    ├── dog6.jpg
           ・
           ・
           ・

画像を用意するのが面倒という方はこちらのリポジトリをクローンしてください。

# Dockerコンテナ内で
$ git clone https://github.com/tsmiyamoto/image_classification_tutorial.git

学習を開始する

デフォルトでは300epochs学習を行います。

訓練
$ python3 train.py image_classification_tutorial/train --pretrained --model tf_efficientnetv2_s_in21ft1k -b 16

推論する

推論はバッチ処理が行われます。

推論
python3 inference.py image_classification_tutorial/test --model tf_efficientnetv2_s_in21ft1k --checkpoint output/train/2021xxxx-xxxxxx-tf_efficientnetv2_s_in21ft1k-300/model_best.pth.tar

処理が完了すると、デフォルトではコマンドを実行したディレクトリ内(上に従っていれば /workspace/pytorch-image-models)にtopk_ids.csvというファイルが生成されます。

topk_ids.csv
cat1.jpg,0,1,853,392,124
cat2.jpg,0,1,629,432,921
cat3.jpg,0,1,274,921,262
cat4.jpg,0,1,123,653,923
cat5.jpg,0,1,720,911,511
dog1.jpg,1,0,931,293,323
dog2.jpg,1,0,833,842,813
dog3.jpg,1,0,782,222,784
dog4.jpg,1,0,783,169,989
dog5.jpg,1,0,736,628,892

出力の意味

この出力は、左から

ファイル名|推論結果(最も確率が高いクラス)|推論結果(2番目に確率が高いクラス)|3番目|4番目|5番目

となっています。
今回は学習させたのが2クラスに対して、デフォルトはトップ5クラスの推論結果を出力するため3番目以降はおかしな数字になっていますが気にしなくて大丈夫です。

推論コマンド実行時に--topk 2とすることでトップ2クラスだけを出力するように変更できます。(参考:inference.py

クラス名と数字の対応関係

学習時のディレクトリ名の昇順(A → Z)に数字が割り当てられます。
サンプルだと以下の通りです。

クラス(ディレクトリ)名 クラスNo.
cat 0
dog 1

カスタマイズする

変更可能なモデルやパラメータの中で、よく使われるものを抜粋して使用法を記述します。

モデルを変更する

TIMMには様々なモデルが用意されていますが、すべてのモデルに事前学習済みモデルがあるわけではないようです。
データが少ない場合には基本的に事前学習済みモデルを使いましょう。

使用方法

train.pyの引数--modelで設定可能です。

$ python3 train.py (クラス別に画像が入ったディレクトリ) --pretrained --model (モデル名)

以下、各モデルで「コマンドライン引数に指定できるもの」に書いてある文字列を(モデル名)のところに当てはめることでモデルを変更することが可能です。(かっこは不要)

事前学習済みモデルがあるもの(一部抜粋)

EfficientNetV2

先で使用したものです。(わかりやすい解説はこちら
2019年にSoTAを達成したEfficientNetのバージョン2です。パラメータ数が少なく、学習が速いという特徴があります。
特段の理由がない限りV2を使っておけばよいでしょう。

EfficientNetV2は2.1万枚の画像で訓練を行い、1000枚の画像でファインチューニングを行ったモデルが用意されているため非常に安定しているイメージです。

コマンドライン引数に指定できるもの
  • tf_efficientnetv2_s_in21ft1k
  • tf_efficientnetv2_m_in21ft1k
  • tf_efficientnetv2_l_in21ft1k
  • tf_efficientnetv2_xl_in21ft1k
    • Validationがきちんと行われていないのでとてもセンシティブでロバストネスがないと説明されています

sよりm、mよりlの方が精度は上がりやすいと思いますが、その分パラメータ数も多くなり、学習に時間がかかります。
また、バッチサイズ(変更方法は後述)をかなり小さくしないとGPUのメモリに載らないと思います。

Vision Transformer

NLPのデファクトスタンダードとなっているモデル「Transformer」を画像に応用するとSoTA同等の性能が達成できたとして2020年に話題となったモデルです。(わかりやすい解説はこちら

コマンドライン引数に指定できるもの
  • vit_tiny_patch16_224_in21k
  • vit_small_patch32_224_in21k
  • vit_small_patch16_224_in21k
  • vit_base_patch32_224_in21k
  • vit_base_patch16_224_in21k
  • vit_base_patch8_224_in21k

Patchの説明は先に挙げた解説記事にあります。この差が精度にどう影響を与えるのかはよくわかっていません。

バッチサイズを変更する

機械学習においては訓練データをいくつかのデータセットに分けます。この、1つのデータセットの大きさのことをバッチサイズと呼びます。

デフォルトは128となっていますが、GPUのメモリ容量不足の時にはバッチサイズを小さくすることで対処することがあります。

一般的には2のべき乗を指定します。ただ、小さすぎると学習に悪影響を与えるので、数百件のデータの場合は16, 32, 64あたりを指定すればよいと思います。

使用方法

train.pyの引数-bまたは--batch-sizeで設定可能です。

$ python3 train.py (クラス別に画像が入ったディレクトリ) -b 32 --pretrained --model (モデル名) 

終わりに

よくあるエラーであったり、他の引数についても追記していきたいと思います。
不明な点はコメントしていただければと思います。

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3