
How to use image classification: the things you think you know but don't

Image classification feels like a basic task, but when it actually comes time to do it, it's easy to forget how.
There are plenty of state-of-the-art architectures out there, but torchvision's models already give pretty good accuracy, and for me that is usually enough.

How to run inference

For the list of available models and weights, see the torchvision models documentation.
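They can also be enumerated from code; a minimal sketch (assumes torchvision 0.14 or later, which provides list_models and get_model_weights):

import torchvision.models as models

# List every classification model builder and its available pretrained weights
for name in models.list_models(module=models):
    for weights_enum in models.get_model_weights(name):
        print(name, weights_enum)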

Model initialization

import torchvision.models as models

weights = models.ViT_L_16_Weights.IMAGENET1K_SWAG_E2E_V1
model = models.vit_l_16(weights=weights)

Image preprocessing

from PIL import Image

# Preprocessing (resize, crop, normalize) that matches the chosen weights
transforms = weights.transforms()

im = Image.open("cat.jpg")
input = transforms(im)
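The transforms bundled with the weights apply exactly the resize, crop, and normalization that this checkpoint was trained with; printing them is a quick way to check what the preprocessing actually does (a small sanity check, output depends on the torchvision version):

# Inspect the preprocessing bundled with the weights
print(transforms)
print(input.shape)  # e.g. torch.Size([3, 512, 512]) for this SWAG checkpoint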

Loading the class names

These are the ImageNet-1k class names.

classes = weights.meta["categories"]
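ImageNet-1k has 1000 classes, so a quick sanity check is to look at the length and the first few entries:

print(len(classes))  # 1000
print(classes[:3])   # first few class names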

Inference

import torch

model.eval()  # switch to evaluation mode before inference

with torch.no_grad():
    outputs = model(input.unsqueeze(dim=0).to("cpu"))
    outputs = torch.nn.Softmax(dim=1)(outputs)
    conf, pred = torch.max(outputs, 1)
    print(classes[int(pred)])
    print(float(conf[0]))

Running inference on this image (cat.jpg) gives:

tabby, tabby cat
0.5702
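If you want more than the single best class, torch.topk returns the top-k predictions; a small variation on the snippet above:

# Top-5 predictions for the same image
with torch.no_grad():
    outputs = model(input.unsqueeze(dim=0).to("cpu"))
    probs = torch.nn.Softmax(dim=1)(outputs)
    top_conf, top_idx = torch.topk(probs, k=5, dim=1)
    for c, i in zip(top_conf[0], top_idx[0]):
        print(classes[int(i)], float(c))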

Speed comparison

Tested on a T4, the cheapest GPU on Colab.

model                                    | ImageNet Top-1 Acc (%) | sec
ViT_L_16_Weights.IMAGENET1K_SWAG_E2E_V1  | 88.1                   | 13.3
ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1  | 85.3                   | 2.2
Swin_S_Weights.IMAGENET1K_V1             | 83.2                   | 0.4
ConvNeXt_Small_Weights.IMAGENET1K_V1     | 83.6                   | 0.3
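The sec column is the measured inference time on that GPU. If you want to reproduce a comparison like this, a rough timing sketch (assuming a CUDA device and the preprocessed input tensor from above; not the exact script used for the table) looks like this:

import time

model = model.to("cuda").eval()
x = input.unsqueeze(dim=0).to("cuda")

# Warm-up run so CUDA initialization is not counted
with torch.no_grad():
    model(x)

torch.cuda.synchronize()
start = time.time()
with torch.no_grad():
    model(x)
torch.cuda.synchronize()
print(f"{time.time() - start:.2f} sec")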

Training

Import the required libraries

from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy

cudnn.benchmark = True
plt.ion()   # interactive mode

Preparing the data

Prepare the images for train and val, sorted into one folder per class, with the following directory structure:

my_dataset
  |
  |__train
  |     |__class0
  |     |     |__****.jpg
  |     |     |__****.jpg
  |     |     |
  |     |     
  |     |__class1
  |     |
  |     
  |__val
        |__class0
        |     |__****.jpg
        |     |__****.jpg
        |     |
        |     
        |__class1
        |

Turn the data into torchvision datasets.
For the resize size, check the page of each model in the torchvision documentation.

# Data augmentation and normalization for training
# Just normalization for validation
data_transforms = {
    'train': transforms.Compose([
        transforms.Resize([256,256]),
        # transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize([256,256]),
        # transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

data_dir = 'my_dataset'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                             shuffle=True, num_workers=4)
              for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Display a batch to check the data. The class names are also picked up here.

def imshow(inp, title=None):
    """Imshow for Tensor."""
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated


# Get a batch of training data
inputs, classes = next(iter(dataloaders['train']))

# Make a grid from batch
out = torchvision.utils.make_grid(inputs)

imshow(out, title=[class_names[x] for x in classes])

The train function

The save destination for the weight checkpoints is specified inside the function.

def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]
            # Make sure the checkpoint directory exists, then save a checkpoint every epoch
            os.makedirs("save_dir", exist_ok=True)
            save_path = "save_dir/my_model" + str(epoch) + ".pt"
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss,
            }, save_path)

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:4f}')

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model

Initialize the pretrained model and the training setup

Here a pretrained Swin V2 Small is used as the base model (the same architecture that the inference code below loads), and its classification head is replaced to match the number of classes.

model = models.swin_v2_s(weights=models.Swin_V2_S_Weights.IMAGENET1K_V1)
num_ftrs = model.head.in_features
model.head = nn.Linear(num_ftrs, len(class_names))

model = model.to(device)

criterion = nn.CrossEntropyLoss()

# Observe that all parameters are being optimized
optimizer_ft = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=0.02)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

Start training

model_ft = train_model(model, criterion, optimizer_ft, exp_lr_scheduler,
                       num_epochs=1000)
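Since each checkpoint also stores the optimizer state and the epoch number, training can be resumed from a saved file; a minimal sketch (the checkpoint path is just an example):

# Resume from a saved checkpoint (path is an example)
checkpoint = torch.load("save_dir/my_model7.pt")
model.load_state_dict(checkpoint['model_state_dict'])
optimizer_ft.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1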

Inference with the trained weights

import torch
from torchvision import models, transforms
from PIL import Image

# weights_path: path to one of the checkpoints saved during training
model = models.swin_v2_s(weights=None)
num_ftrs = model.head.in_features
model.head = torch.nn.Linear(num_ftrs, len(class_names))

if torch.cuda.is_available():
    checkpoint = torch.load(weights_path)
    device = 'cuda'
else:
    checkpoint = torch.load(weights_path, map_location=torch.device('cpu'))
    device = 'cpu'

model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
model = model.to(device)
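The snippet above only loads the fine-tuned weights; actually classifying an image works the same way as before, except the preprocessing should match the val transforms used during training. A sketch (the test image path is just an example):

# Same preprocessing as the 'val' transforms used during training
preprocess = transforms.Compose([
    transforms.Resize([256, 256]),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

im = Image.open("test.jpg")  # example test image path
input = preprocess(im)

with torch.no_grad():
    outputs = model(input.unsqueeze(dim=0).to(device))
    probs = torch.nn.Softmax(dim=1)(outputs)
    conf, pred = torch.max(probs, 1)
    print(class_names[int(pred)])  # predicted class name
    print(float(conf[0]))          # confidence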