3
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

Pytorchの転移学習でフレンチブルドッグとブルドッグを見分ける

Last updated at Posted at 2019-06-27

はじめに

実家で飼っているフレンチブルドッグを散歩をしていると、おばちゃんに「あら、可愛いブルドッグね」とよく声をかけられます。
そこで、機械学習を使って、フレンチブルドッグとブルドッグを分類するモデルをつくろうと思います。

手順

モデルの構築は、以下の手順で行いました。

  • データの用意
  • モデルの学習(Jupyter Notebook)

データの用意〜モデルの学習に際しては、ABEJA Platformを使いました。

データの用意

データは、Stanford Dogs Datasetを使います。
これは120の犬種カテゴリ、20,580枚の画像から成るデータセットです。

今回は、主にフレンチブルドッグとブルドックの分類に関心があるため、120犬種の中から、ブル系の犬種のデータのみを選定して使用します。

id 0: ボストンブル
id 1: ブルマスティフ
id 2: フレンチブルドッグ
id 3: スタッフォードシャーブルテリア

データセットの用意はこちらの手順に詳細を記載しています。
https://qiita.com/yushin_n/items/98fcc788710c0ace3a4a

モデルの学習(Jupyter Notebook)

モデルの学習はABEJA PlatformのJupyter Notebook(インスタンスタイプ:gpu-1)で実施します。

ライブラリをインポートします。


import io
import numpy as np
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from PIL import Image
from tqdm import tqdm

from abeja.datasets import Client as DatasetClient

# check if CUDA is available
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')

Datasetsをロードします。


# import data from ABEJA Platform Dataset
def load_dataset_from_api(dataset):
    for item in tqdm(dataset.dataset_items.list(prefetch=True)):
        try:
            file_content = item.source_data[0].get_content()
            file_like_object = io.BytesIO(file_content)
            img = Image.open(file_like_object).convert('RGB')
            label_id = item.attributes['classification'][0]['label_id']
            yield img, label_id
        except OSError:
            print('Fail to load dataset_item_id {}.'.format(item.dataset_item_id))

# define dataset id
dataset_id = 'XXXXXXXXXXXXX'

# get dataset via ABEJA Platform api
dataset_client = DatasetClient()
dataset = dataset_client.get_dataset(dataset_id)
dataset_list = list(load_dataset_from_api(dataset))
num_classes = len(dataset.props['categories'][0]['labels'])
print('number of classes is {}.'.format(num_classes))

データを1枚表示してみます。


import matplotlib.pyplot as plt                        
%matplotlib inline
# show one of images
for item in dataset_list:
    plt.imshow(item[0])
    plt.title('label id is {}'.format(item[1]))
    plt.show()
    break

きちんとデータがロードされていることが確認できました。
download-1.png

続いて、PytorchのDatasetをカスタマイズします。


# define customized dataset class
class CustomizedDataset(Dataset):

    def __init__(self, dataset_list, transform=None):
        self.transform = transform
        self.dataset_list = dataset_list
        self.img = [item[0] for item in self.dataset_list]
        self.label_id = [item[1] for item in self.dataset_list]

    def __len__(self):
        return len(self.img)

    def __getitem__(self, idx):
        out_img = self.img[idx]
        out_label_id = self.label_id[idx]

        if self.transform:
            out_img = self.transform(out_img)

        return out_img, out_label_id

データの前処理を行うTransformを設定します。

Pytorchのドキュメントの学習済みモデルの作法に沿って、Shape = [3 x 224 x 224]にクロップ&リサイズして、mean =[0.485, 0.456, 0.406]、std = [0.229, 0.224, 0.225]を用いて標準化を行いました。また、Training用データは、回転、フリップの処理を加えました。


# convert data to a normalized torch.FloatTensor
train_transform = transforms.Compose([transforms.Resize(size=256),
                  transforms.CenterCrop((224, 224)),
                  transforms.RandomHorizontalFlip(),  # randomly flip and rotate
                  transforms.RandomRotation(10),
                  transforms.ToTensor(),
                  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

valid_transform = transforms.Compose([transforms.Resize(size=256),
                  transforms.CenterCrop((224, 224)),
                  transforms.ToTensor(),
                  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

Transformを用いて、Traning用とValidation用のDataLoaderを作成します。

valid_size = 0.2
num_workers = 0
batch_size = 12

# prepare dataloader
def load_split_train_valid(dataset_list):
    """ Split dataset into training and validation set """

    train_data = CustomizedDataset(dataset_list, transform=train_transform)
    valid_data = CustomizedDataset(dataset_list, transform=valid_transform)

    num_train = len(train_data)
    indices = list(range(num_train))
    split = int(np.floor(valid_size * num_train))
    np.random.shuffle(indices)

    train_idx, valid_idx = indices[split:], indices[:split]
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)
    trainloader = DataLoader(train_data, sampler=train_sampler,
                             batch_size=batch_size, num_workers=num_workers)
    validloader = DataLoader(valid_data, sampler=valid_sampler,
                             batch_size=batch_size, num_workers=num_workers)

    print('number of train datasets is {}.'.format(len(trainloader) * batch_size))
    print('number of valid datasets is {}.'.format(len(validloader) * batch_size))

    return trainloader, validloader

# create dataloader
trainloader, validloader = load_split_train_valid(dataset_list)

続いて、学習です。
今回は学習済みのResNet-50をロードして、畳み込み層をフリーズし最後の全結合層のみを学習させます。


# set hyperparameters
n_epochs = 10
learning_rate = 0.001

# set save path
save_path = './model.pt'

# specify model architecture (ResNet-50)
model = models.resnet50(pretrained=True)

# freeze parameters so we don't backprop through them
for param in model.parameters():
    param.requires_grad = False

# replace the last fully connected layer with a Linnear layer with no. of classes out features
model.fc = nn.Linear(2048, num_classes)
model = model.to(device)

# define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.fc.parameters(), lr=learning_rate)


def train_model(trainloader, validloader, model, optimizer, criterion):
    """returns trained model"""
    # initialize tracker for minimum validation loss
    valid_loss_min = 3.877533  # np.Inf

    if os.path.exists(save_path):
        model.load_state_dict(torch.load(save_path))

    for epoch in range(n_epochs):
        # initialize variables to monitor training and validation loss and accuracy
        train_loss = 0.0
        train_total = 0
        train_correct = 0
        valid_loss = 0.0
        valid_total = 0
        valid_correct = 0

        # train the model
        model.train()
        for data, target in trainloader:
            data, target = data.to(device), target.to(device)

            # clear the gradients of all optimized variables
            optimizer.zero_grad()
            # forward pass: compute predicted outputs by passing inputs to the model
            output = model(data)
            # calculate the batch loss
            loss = criterion(output, target)
            # backward pass: compute gradient of the loss with respect to model parameters
            loss.backward()
            # perform a single optimization step (parameter update)
            optimizer.step()
            # update training loss
            train_loss += loss.item() * data.size(0)
            # count number of correct labels
            _, preds_tensor = torch.max(output, 1)
            train_total += target.size(0)
            train_correct += (preds_tensor == target).sum().item()

        # validate the model
        model.eval()
        for data, target in validloader:
            data, target = data.to(device), target.to(device)

            # forward pass: compute predicted outputs by passing inputs to the model
            output = model(data)
            # calculate the batch loss
            loss = criterion(output, target)
            # update average validation loss
            valid_loss += loss.item() * data.size(0)
            # count number of correct labels
            _, preds_tensor = torch.max(output, 1)
            valid_total += target.size(0)
            valid_correct += (preds_tensor == target).sum().item()

        # calculate average losses
        train_loss /= len(trainloader.dataset)
        valid_loss /= len(validloader.dataset)
        # calculate accuracy
        train_acc = train_correct / train_total
        valid_acc = valid_correct / valid_total

        # print training/validation statistics
        print('Epoch: {} \tTrain loss: {:.6f} \tTrain acc: {:.6f} \tValid loss: {:.6f} \tValid acc: {:.6f}'.format(
                epoch + 1,
                train_loss,
                train_acc,
                valid_loss,
                valid_acc))

        # save model if validation loss has decreased
        if valid_loss <= valid_loss_min:
            print('Validation loss decreased ({:.6f} --> {:.6f}). Saving model.'.format(
                valid_loss_min,
                valid_loss))
            torch.save(model.state_dict(), save_path)
            valid_loss_min = valid_loss
    # return trained model
    return model

# train the model
train_model(trainloader, validloader, model, optimizer, criterion)

10エポック学習して、Validation Accuracyが95.38%となりました。なかなかの精度でフレンチブルドッグとブルドッグを見分けることができているようです。

最後に、自分の持っている画像データでテストをしてみます。先ほど保存したモデルのパラメータをロードします。

# prediction
import os
import numpy as np

import torch
import torchvision.transforms as transforms
import torch.nn as nn
import torchvision.models

num_classes = 4

# load the model for CPU
device = torch.device('cpu')
model = torchvision.models.resnet50(pretrained=True)

for param in model.parameters():
    param.requires_grad = False

model.fc = nn.Linear(2048, num_classes)
model.load_state_dict(torch.load(save_path, map_location=device))

テスト用の画像を前処理して、推論を行います。


def image_to_tensor(img):
    img = img.convert('RGB')
    transformations = transforms.Compose([transforms.Resize(size=256),
                      transforms.CenterCrop((224, 224)),
                      transforms.ToTensor(),
                      transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                           std=[0.229, 0.224, 0.225])])
    image_tensor = transformations(img)[:3, :, :].unsqueeze(0)
    return image_tensor


def decode_predictions(result):
    categories = {
        0: 'BOSTON_BULL',
        1: 'BULL_MASTIFF',
        2: 'FRENCH_BULLDOG',
        3: 'STAFFORDSHIRE_BULLTERRIER'
    }
    result_with_labels = [{"label": categories[i], "probability": score} for i, score in enumerate(result)]
    return sorted(result_with_labels, key=lambda x: x['probability'], reverse=True)


def predict(img):
    image_tensor = image_to_tensor(img)
    image_tensor = image_tensor.to(device)
    model.eval()
    output = model(image_tensor)
    # convert output probabilities to predicted class
    softmax = nn.Softmax(dim=1)
    preds_tensor = softmax(output)
    result = np.squeeze(preds_tensor.to(device).detach().numpy())
    sorted_result = decode_predictions(result.tolist())
    return {"result": sorted_result}

フレンチブルドッグの写真を推論にかけてみます。

img = Image.open('./test_french_bull.jpg')
predicted = predict(img)
plt.title("label:{}  prob:{:.2%}".format(predicted['result'][0]['label'],
                                         predicted['result'][0]['probability']))
plt.imshow(img)
plt.show()

フレンチブルドッグと予測しています。
download.png

ブルドッグの写真の場合、ブルマスティフと予測します。

download-1.png

まとめ

本記事では、Pytorchの転移学習でフレンチブルドッグとブルドッグを分類するモデルをつくる手順をまとめました。結果、90%以上の正確性で、フレンチブルドッグとブルドッグを見分けることができました!これで、おばちゃんにも自信を持って「フレンチブルドッグです」と言えますね :dog:

参考

サンプルコード(Jupyter Notebook)
https://github.com/abeja-inc/Platform_handson/blob/master/bulldog_classification/notebook/Bulldog_classification.ipynb

今回モデルの構築に使用した、ABEJA Platformはトライアルも提供しています。気になられた方は、是非、お気軽にお問い合わせください。また、フォーラムもありますので、是非、ご活用ください。

ABEJA Platformに関するお問い合わせ
https://abejainc.com/platform/ja/contact/

ABEJA Platform Forum
https://forums.abeja.io/

3
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?