0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

Pytorchで2クラス分類問題を解く

Last updated at Posted at 2023-07-19

概要

Pytorchで2クラス分類問題を解くプログラムをまとめます。
ChatGPTを使えば一発でそんなプログラムなんて出てくるのになんで今さら とりあえずまとめていきましょう!

環境

  • OS: MacOS 13.4.1
  • DockerDesktop: 4.20.1
  • in container
    • OS: Ubuntu20.04
    • torch: 1.12.1
    • torchvision: 0.13.1

準備

データセット

データはOxford Pets datasetを使います。
以下のコマンドで、任意のディレクトリに保存します。

wget https://thor.robots.ox.ac.uk/~vgg/data/pets/images.tar.gz -P data/pets
tar -xzf data/pets/images.tar.gz -C data/pets
rm data/pets/images.tar.gz

これで./data/pets/ディレクトリに37種類の犬または猫の画像がそれぞれ約200枚ずつ保存されます。
ここではカレントディレクトリにこのデータを保存します。

Docker Container

適当なDockerfileでビルドし、コンテナを立ち上げます。
このとき、カレントディレクトリに保存した./data/pets/images/ディレクトリは結合しておきます。

プログラム

データセット

まずはデータを見てみます。
各画像のラベルはファイル名として記録されている形で、例えばAbyssinian_1.jpgのようなファイル名のとき、_で区切られた前半のAbyssinianがクラス名、その後の数字が付番であることがわかります。
そこで以下のようにラベルを抽出します。

import os
filename = 'Abyssinian_1.jpg'
extension = os.path.splitext(filename)[1]
label = filename.replace(extension, '')
# print(label) 
# Abyssinian

さて、ここでは2クラス分類問題を設定したいので、37種類の画像からどれか2つの種類をピックアップしたいと思います。
英語で書かれた動物の種類は何が何やらわからない 犬と猫の見分けのほうが簡単なため、適当に犬をnewfoundlandpomeranian、猫をAbyssinianBombay、それぞれピックアップしたいと思います。
それぞれファイル名とラベルをリストに格納したものを用意します。種類の名前はここでは使いませんが、後々使うようになるのでここでも格納しておきます。

train.py
import os
path_input = os.path.join('dataset')
list_filename = os.listdir(path_input)

list_filenames = os.listdir(path_input)
list_file = []
for filename in list_filenames:
    if ('newfoundland' in filename) or ('pomeranian' in filename):
        label = 0 # dog
    elif ('Abyssinian' in filename) or ('Bombay' in filename):
        label = 1 # cat
    else:
        continue
    list_file.append([filename, label, filename.split('_')[0]])
print(list_file[0])
# ['Abyssinian_1.jpg', 1, 'Abyssinian']

ここでlabelは、犬は0、猫は1とします。
(余談ですが、このデータセットで品種の頭文字が小文字の場合は犬、大文字の場合は猫のようです)

trainvalidationtestの3つにデータを分けておきます。
scikit-learntrain_test_splitをうまく使います。

train.py
list_train, list_val = train_test_split(list_file, shuffle=True, random_state=random_seed, test_size=0.2)
list_val, list_test = train_test_split(list_val, shuffle=True, random_state=random_seed, test_size=0.5)

これで画像のファイル名とラベルの準備は完了です。

Datasetを作る

以下のようにデータセットを用意します。

utils.py
import os
from PIL import Image
import torch.utils.data as data

class MyDataset(data.Dataset):
    def __init__(self, list_file, transform=None, phase='train'):
        self.list_file = list_file
        self.transform = transform
        self.phase = phase

    def __len__(self):
        # ファイル数を返す
        return len(self.list_file)

    def __getitem__(self, index):
        # 画像をPillowsで開く
        path_input = './data/pets/images/'
        path_image = os.path.join(path_input, self.list_file[index][0])
        pil_image = Image.open(path_image)

        # 画像の前処理
        image_transformed = self.transform(pil_image).convert('RGB')

        # ラベルを取得
        label_class = self.list_file[index][1]
        label_type = self.list_file[index][2]
        return image_transformed, label_class

画像の前処理は以下のクラスを用意します。

utils.py
from torchvision import transforms

class ImageTransform():
    def __init__(self, resize, mean, std):
        self.data_transform = transforms.Compose([
            transforms.Resize(resize),
            transforms.CenterCrop(resize),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])

    def __call__(self, image):
        return self.data_transform(image)

今回はImageNetで学習済のVGG16モデルを使用するので、インスタンスは次のように設定します。

train.py
resize = 224
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
transform = ImageTransform(resize, mean, std)

Datasetのインスタンスは次のように設定します。

train.py
# dataset
train_dataset = MyDataset(list_train, path_input, transform=transform, phase='train')
val_dataset = MyDataset(list_val, path_input, transform=transform, phase='val')

DataLoaderを用意する

DataLoaderは以下のように用意します。

train.py
# dataloader
train_dataloader = data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

デバイスの選択

デバイスを選択します。GPUがあればここで設定します。

train.py
# デバイスを選択
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

ネットワークの選択

VGG16を選びます。

train.py
from torchvision import models
net = models.vgg16(weights='VGG16_Weights.IMAGENET1K_V1')
# 出力層を2つに付け替える
net.classifier[6] = nn.Linear(in_features=4096, out_features=2)

ここで、models.vgg16(pretrained=True)としても動きますが、以下の警告が出るようになったため、weightで設定しました。

/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and will be removed in 0.15, please use 'weights' instead.
  warnings.warn(
/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and will be removed in 0.15. The current behavior is equivalent to passing `weights=VGG16_Weights.IMAGENET1K_V1`. You can also use `weights=VGG16_Weights.DEFAULT` to get the most up-to-date weights.

損失関数

以下のように設定します。2クラス分類のため、特にこだわりもなくクロスエントロピー誤差を選択します。

train.py
import torch.nn as nn
# 損失関数の定義
criterion = nn.CrossEntropyLoss()

最適化関数

以下のように設定します。ここでも特にこだわり無くAdamを選択します。
また、ここでは転移学習を選択します。ついでなのでファインチューニングを選択する場合も記載しておきます。

train.py
import torch.optim as optim
lr = 0.001
USE_FINE_TUNING = False
# 最適化手法の選択
if USE_FINE_TUNING:
    # 最適化手法を設定
    optimizer = optim.Adam(net.parameters(), lr=lr)
else:
    # 転移学習
    params_to_update = []
    update_param_names = ['classifier.6.weight', 'classifier.6.bias']
    for name, param in net.named_parameters():
        if name in update_param_names:
            param.requires_grad = True
            params_to_update.append(param)
        else:
            param.requires_grad = False
        optimizer = optim.Adam(params=params_to_update, lr=lr)

結果の保存

ログを取ります。今回はiteration数、経過時間、lossaccuracyの4つを保存するように設定します。

学習

trainvalを実行します。
valはあるイテレーション数で実行するため関数にしてまとめておきましょう。
ここでは返り値として、accuracylossを選択します。
accuracyは楽なのでscikit-learnのAPIを使います。

train.py
from sklearn.metrics import accuracy_score

def validation(net, device, criterion, val_dataloader):
    net.eval()
    total_loss = 0
    Y = []
    preds = []
    with tqdm(total=len(val_dataloader)) as pbar:
        pbar.set_description('validation')
        for inputs, labels in val_dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            with torch.no_grad():
                outputs = net(inputs)
            loss = criterion(outputs, labels)
            _, pred = torch.max(outputs, 1) # ラベルの予想
            total_loss += loss.item() * inputs.size(0)
            Y.extend(labels)
            preds.extend(pred)
            pbar.update(1)
    return accuracy_score(y_true=Y, y_pred=preds), total_loss

学習は以下のように行います。
進行状況がわかるようにtqmdで表示します。
また、ここではepochではなくiterationごとに学習の進行具合を確認しています。

train.py
from tqdm import tqdm

count = 0
iteration = 0
total_loss = 0
Y_train = []
pred_train = []
time_trainval_total_start = time.perf_counter()
with tqdm(total=max_itr) as pbar:
    pbar.set_description('training')
    while iteration < max_itr:
        for inputs, labels in train_dataloader:
            if iteration >= max_itr:
                break
            inputs = inputs.to(device)
            labels = labels.to(device)
            if count == 0:
                net.train()
                time_trainval_interval_start = time.perf_counter()
            optimizer.zero_grad()
            with torch.set_grad_enabled(True):
                outputs = net(inputs)
                loss = criterion(outputs, labels)
                _, pred = torch.max(outputs, 1) # ラベルの予測

                # バックプロパゲーション
                loss.backward()
                optimizer.step()

            # カウント
            count += 1
            iteration += 1
            # 損失計算
            total_loss += loss.item() * inputs.size(0)
            # スコア計算用
            Y_train.extend(labels)
            pred_train.extend(pred)

            if count == val_interval:
                time_trainval_interval_end = time.perf_counter()
                time_trainval_interval = time_trainval_interval_end - time_trainval_interval_start
                total_loss = total_loss / val_interval
                # validation
                acc_score, loss_val, Y, preds = validation(net, device=device, criterion=criterion, val_dataloader=val_dataloader)
                # save log
                ## training
                with open(path_save_logfile_train, 'a') as logfile:
                    logfile.write('{},{},{},{}\n'.format(iteration, time_trainval_interval, total_loss, accuracy_score(y_true=Y_train, y_pred=pred_train)))
                ## validation
                with open(path_save_logfile_val, 'a') as logfile:
                    logfile.write('{},{},{},{}\n'.format(iteration, time_trainval_interval, loss_val, acc_score))

                # 結果の描画

                # reset
                count = 0
                Y_train = []
                pred_train = []

            pbar.update(1)

結果の確認

結果を確認します。イテレーションごとにaccuracyが上昇しているのがわかります。
今回はVGG16の学習済モデルを使い、分類する対象も犬と猫の画像だったため、転移学習でもすぐにaccuracy1.0になりました。
もう少し難しいデータセットを採用すると、工夫のしがいがあるかもしれません。
また、結果もグラフで表示できるようにしておくと、計算途中で性能がどれくらいか見積もることができ、便利だと思います。

log_score_val.csv
iteration,time,loss,acc
0,107.61086942399561,584.3627863526344,0.40866035182679294
10,51.51534285600064,0.16899609718075226,1.0
20,45.83238556300057,0.021452696702795038,1.0

Github

ここまでで断片的に紹介したプログラムは以下のリポジトリに載せています。

最後に

今回は2クラスの画像分類をpytorchで書くプログラムを紹介しました。
データセットも猫と犬の画像で分類しやすいものを紹介しました。
accuracyもすぐ1.0になったので、かなり分類しやすい問題だったのだと思います。
今後はこのモデルをベースに、より難しい状況での分類問題を考えていきたいと思います。

参考

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?