概要
Pytorchで2クラス分類問題を解くプログラムをまとめます。
ChatGPTを使えば一発でそんなプログラムなんて出てくるのになんで今さら とりあえずまとめていきましょう!
環境
- OS: MacOS 13.4.1
- DockerDesktop: 4.20.1
- in container
- OS: Ubuntu20.04
- torch: 1.12.1
- torchvision: 0.13.1
準備
データセット
データはOxford Pets dataset
を使います。
以下のコマンドで、任意のディレクトリに保存します。
wget https://thor.robots.ox.ac.uk/~vgg/data/pets/images.tar.gz -P data/pets
tar -xzf data/pets/images.tar.gz -C data/pets
rm data/pets/images.tar.gz
これで./data/pets/
ディレクトリに37種類の犬または猫の画像がそれぞれ約200枚ずつ保存されます。
ここではカレントディレクトリにこのデータを保存します。
Docker Container
適当なDockerfileでビルドし、コンテナを立ち上げます。
このとき、カレントディレクトリに保存した./data/pets/images/
ディレクトリは結合しておきます。
プログラム
データセット
まずはデータを見てみます。
各画像のラベルはファイル名として記録されている形で、例えばAbyssinian_1.jpg
のようなファイル名のとき、_
で区切られた前半のAbyssinian
がクラス名、その後の数字が付番であることがわかります。
そこで以下のようにラベルを抽出します。
import os
filename = 'Abyssinian_1.jpg'
extension = os.path.splitext(filename)[1]
label = filename.replace(extension, '')
# print(label)
# Abyssinian
さて、ここでは2クラス分類問題を設定したいので、37種類の画像からどれか2つの種類をピックアップしたいと思います。
英語で書かれた動物の種類は何が何やらわからない 犬と猫の見分けのほうが簡単なため、適当に犬をnewfoundland
とpomeranian
、猫をAbyssinian
とBombay
、それぞれピックアップしたいと思います。
それぞれファイル名とラベルをリストに格納したものを用意します。種類の名前はここでは使いませんが、後々使うようになるのでここでも格納しておきます。
import os
path_input = os.path.join('dataset')
list_filename = os.listdir(path_input)
list_filenames = os.listdir(path_input)
list_file = []
for filename in list_filenames:
if ('newfoundland' in filename) or ('pomeranian' in filename):
label = 0 # dog
elif ('Abyssinian' in filename) or ('Bombay' in filename):
label = 1 # cat
else:
continue
list_file.append([filename, label, filename.split('_')[0]])
print(list_file[0])
# ['Abyssinian_1.jpg', 1, 'Abyssinian']
ここでlabel
は、犬は0
、猫は1
とします。
(余談ですが、このデータセットで品種の頭文字が小文字の場合は犬、大文字の場合は猫のようです)
train
とvalidation
、test
の3つにデータを分けておきます。
scikit-learn
のtrain_test_split
をうまく使います。
list_train, list_val = train_test_split(list_file, shuffle=True, random_state=random_seed, test_size=0.2)
list_val, list_test = train_test_split(list_val, shuffle=True, random_state=random_seed, test_size=0.5)
これで画像のファイル名とラベルの準備は完了です。
Dataset
を作る
以下のようにデータセットを用意します。
import os
from PIL import Image
import torch.utils.data as data
class MyDataset(data.Dataset):
def __init__(self, list_file, transform=None, phase='train'):
self.list_file = list_file
self.transform = transform
self.phase = phase
def __len__(self):
# ファイル数を返す
return len(self.list_file)
def __getitem__(self, index):
# 画像をPillowsで開く
path_input = './data/pets/images/'
path_image = os.path.join(path_input, self.list_file[index][0])
pil_image = Image.open(path_image)
# 画像の前処理
image_transformed = self.transform(pil_image).convert('RGB')
# ラベルを取得
label_class = self.list_file[index][1]
label_type = self.list_file[index][2]
return image_transformed, label_class
画像の前処理は以下のクラスを用意します。
from torchvision import transforms
class ImageTransform():
def __init__(self, resize, mean, std):
self.data_transform = transforms.Compose([
transforms.Resize(resize),
transforms.CenterCrop(resize),
transforms.ToTensor(),
transforms.Normalize(mean, std)
])
def __call__(self, image):
return self.data_transform(image)
今回はImageNet
で学習済のVGG16
モデルを使用するので、インスタンスは次のように設定します。
resize = 224
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
transform = ImageTransform(resize, mean, std)
Dataset
のインスタンスは次のように設定します。
# dataset
train_dataset = MyDataset(list_train, path_input, transform=transform, phase='train')
val_dataset = MyDataset(list_val, path_input, transform=transform, phase='val')
DataLoader
を用意する
DataLoader
は以下のように用意します。
# dataloader
train_dataloader = data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
デバイスの選択
デバイスを選択します。GPUがあればここで設定します。
# デバイスを選択
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
ネットワークの選択
VGG16
を選びます。
from torchvision import models
net = models.vgg16(weights='VGG16_Weights.IMAGENET1K_V1')
# 出力層を2つに付け替える
net.classifier[6] = nn.Linear(in_features=4096, out_features=2)
ここで、models.vgg16(pretrained=True)
としても動きますが、以下の警告が出るようになったため、weight
で設定しました。
/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and will be removed in 0.15, please use 'weights' instead.
warnings.warn(
/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and will be removed in 0.15. The current behavior is equivalent to passing `weights=VGG16_Weights.IMAGENET1K_V1`. You can also use `weights=VGG16_Weights.DEFAULT` to get the most up-to-date weights.
損失関数
以下のように設定します。2クラス分類のため、特にこだわりもなくクロスエントロピー誤差を選択します。
import torch.nn as nn
# 損失関数の定義
criterion = nn.CrossEntropyLoss()
最適化関数
以下のように設定します。ここでも特にこだわり無くAdam
を選択します。
また、ここでは転移学習を選択します。ついでなのでファインチューニングを選択する場合も記載しておきます。
import torch.optim as optim
lr = 0.001
USE_FINE_TUNING = False
# 最適化手法の選択
if USE_FINE_TUNING:
# 最適化手法を設定
optimizer = optim.Adam(net.parameters(), lr=lr)
else:
# 転移学習
params_to_update = []
update_param_names = ['classifier.6.weight', 'classifier.6.bias']
for name, param in net.named_parameters():
if name in update_param_names:
param.requires_grad = True
params_to_update.append(param)
else:
param.requires_grad = False
optimizer = optim.Adam(params=params_to_update, lr=lr)
結果の保存
ログを取ります。今回はiteration
数、経過時間、loss
、accuracy
の4つを保存するように設定します。
学習
train
とval
を実行します。
val
はあるイテレーション数で実行するため関数にしてまとめておきましょう。
ここでは返り値として、accuracy
とloss
を選択します。
accuracy
は楽なのでscikit-learn
のAPIを使います。
from sklearn.metrics import accuracy_score
def validation(net, device, criterion, val_dataloader):
net.eval()
total_loss = 0
Y = []
preds = []
with tqdm(total=len(val_dataloader)) as pbar:
pbar.set_description('validation')
for inputs, labels in val_dataloader:
inputs = inputs.to(device)
labels = labels.to(device)
with torch.no_grad():
outputs = net(inputs)
loss = criterion(outputs, labels)
_, pred = torch.max(outputs, 1) # ラベルの予想
total_loss += loss.item() * inputs.size(0)
Y.extend(labels)
preds.extend(pred)
pbar.update(1)
return accuracy_score(y_true=Y, y_pred=preds), total_loss
学習は以下のように行います。
進行状況がわかるようにtqmd
で表示します。
また、ここではepoch
ではなくiteration
ごとに学習の進行具合を確認しています。
from tqdm import tqdm
count = 0
iteration = 0
total_loss = 0
Y_train = []
pred_train = []
time_trainval_total_start = time.perf_counter()
with tqdm(total=max_itr) as pbar:
pbar.set_description('training')
while iteration < max_itr:
for inputs, labels in train_dataloader:
if iteration >= max_itr:
break
inputs = inputs.to(device)
labels = labels.to(device)
if count == 0:
net.train()
time_trainval_interval_start = time.perf_counter()
optimizer.zero_grad()
with torch.set_grad_enabled(True):
outputs = net(inputs)
loss = criterion(outputs, labels)
_, pred = torch.max(outputs, 1) # ラベルの予測
# バックプロパゲーション
loss.backward()
optimizer.step()
# カウント
count += 1
iteration += 1
# 損失計算
total_loss += loss.item() * inputs.size(0)
# スコア計算用
Y_train.extend(labels)
pred_train.extend(pred)
if count == val_interval:
time_trainval_interval_end = time.perf_counter()
time_trainval_interval = time_trainval_interval_end - time_trainval_interval_start
total_loss = total_loss / val_interval
# validation
acc_score, loss_val, Y, preds = validation(net, device=device, criterion=criterion, val_dataloader=val_dataloader)
# save log
## training
with open(path_save_logfile_train, 'a') as logfile:
logfile.write('{},{},{},{}\n'.format(iteration, time_trainval_interval, total_loss, accuracy_score(y_true=Y_train, y_pred=pred_train)))
## validation
with open(path_save_logfile_val, 'a') as logfile:
logfile.write('{},{},{},{}\n'.format(iteration, time_trainval_interval, loss_val, acc_score))
# 結果の描画
# reset
count = 0
Y_train = []
pred_train = []
pbar.update(1)
結果の確認
結果を確認します。イテレーションごとにaccuracy
が上昇しているのがわかります。
今回はVGG16
の学習済モデルを使い、分類する対象も犬と猫の画像だったため、転移学習でもすぐにaccuracy
が1.0
になりました。
もう少し難しいデータセットを採用すると、工夫のしがいがあるかもしれません。
また、結果もグラフで表示できるようにしておくと、計算途中で性能がどれくらいか見積もることができ、便利だと思います。
iteration,time,loss,acc
0,107.61086942399561,584.3627863526344,0.40866035182679294
10,51.51534285600064,0.16899609718075226,1.0
20,45.83238556300057,0.021452696702795038,1.0
Github
ここまでで断片的に紹介したプログラムは以下のリポジトリに載せています。
最後に
今回は2クラスの画像分類をpytorch
で書くプログラムを紹介しました。
データセットも猫と犬の画像で分類しやすいものを紹介しました。
accuracy
もすぐ1.0
になったので、かなり分類しやすい問題だったのだと思います。
今後はこのモデルをベースに、より難しい状況での分類問題を考えていきたいと思います。
参考