1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

Vision Transformerを使って顔画像から年齢予測(PyTorch)

Last updated at Posted at 2023-07-11

概要

Vision Transformerの学習済みモデルをFine-tuningして、人の顔画像から年齢を予測する回帰モデルを構築しました。

データは以下からダウンロードしました。

import os
import numpy as np
import matplotlib.pyplot as plt
import PIL
import csv
import torch
import torchvision
from __future__ import print_function
import glob
import random
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from pathlib import Path
import copy
from tqdm import tqdm

データ準備

Zipファイルを適当な場所で展開
ファイル名から年齢を取得
それぞれの画像ファイルのパスと年齢をセットにしてcsvファイルに出力する

画像が含まれているフォルダを読み込んで、画像ファイル名をいくつか表示すると以下のようになります。

file = './archive/crop_part1' 
folderfile = os.listdir(file)
print(folderfile[0:5])

['47_0_0_20170109010047327.jpg.chip.jpg', '4_0_0_20170110211454141.jpg.chip.jpg', '1_0_2_20161219204653181.jpg.chip.jpg', '52_0_0_20170105173134925.jpg.chip.jpg', '75_1_0_20170110140824258.jpg.chip.jpg']

画像ファイルへのパスと年齢がセットになったcsvファイルを作成

csvfile = []
for i in range(len(folderfile)):
    if folderfile[i][1]=='_':      #0~9歳のファイル
        path = folderfile[i]
        path = './archive/crop_part1/' + path
        label = folderfile[i][0]
        listi = [path,label]       #画像のパスと年齢をセットにする
        csvfile.append(listi)
    elif folderfile[i][2]=='_':    #10~99歳のファイル
        path = folderfile[i]
        path = './archive/crop_part1/' + path
        label = folderfile[i][0:2]
        listi = [path,label]
        csvfile.append(listi)
    else:                          #100歳以上のファイル
        path = folderfile[i]
        path = './archive/crop_part1/' + path
        label = folderfile[i][0:3]
        listi = [path,label]
        csvfile.append(listi)

with open('./age1.csv', 'w') as f:
    writer = csv.writer(f)
    writer.writerows(csvfile)
f.close()

データセット作成

データセット作成用のクラス定義

class MyDataset(torch.utils.data.Dataset):

    def __init__(self, label_path, transform=None):
        x = []
        y = []
        file = open(label_path, 'r')
        data = csv.reader(file)
        for row in data:
            x.append(row[0])
            y.append(float(row[1]))
        file.close()
    
        self.x = x    
        self.y = torch.from_numpy(np.array(y)).float().view(-1, 1)
     
        self.transform = transform
  
  
    def __len__(self):
        return len(self.x)
  
  
    def __getitem__(self, i):
        img = PIL.Image.open(self.x[i]).convert('RGB')
        if self.transform is not None:
              img = self.transform(img)
    
        return img, self.y[i]

画像の前処理用のtransform

transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]
)

先ほどのcsvファイルを読み込みデータセット作成

data_dir = './age1.csv'
dataset = MyDataset(data_dir, transform=transform)


indices = np.arange(len(dataset))

train_dataset = torch.utils.data.Subset(dataset, indices[0:8000])   #学習用データ
val_dataset = torch.utils.data.Subset(dataset, indices[8000:9500])  #検証用データ
test_dataset = torch.utils.data.Subset(dataset, indices[9500:])     #テストデータ

print(f"full: {len(dataset)} -> train: {len(train_dataset)}, val: {len(val_dataset)}, test: {len(test_dataset)}")

full: 9780 -> train: 8000, val: 1500, test: 280

dataloader作成

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=8, shuffle=True)
valid_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4, shuffle=False)
testloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)

モデル構築

モデルの読み込み

ここではvit_b_32というVision Transformerの学習済みモデルを読み込んでいます。

import torchvision.models as models
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = models.vit_b_32(pretrained=True)

model.heads[0] = nn.Linear(768, 1)  #回帰タスクなので出力を1にする
model=model.to(device)
print(device)
# パラメータ設定
epochs = 10      #学習回数
lr = 3e-5        #学習率
seed = 42

損失関数と最適化関数の指定

今回は回帰タスクなので損失関数に平均絶対誤差(MAE)を指定します。

#損失関数と最適化関数
criterion = torch.nn.L1Loss()    #平均絶対誤差(MAE)
optimizer = optim.Adam(model.parameters(), lr=lr)

モデルの学習と検証

best_loss = None

train_loss_list = []
val_loss_list = []

for epoch in range(epochs):
    epoch_loss = 0
    for data, label in tqdm(train_loader):
        data = data.to(device)
        label = label.to(device)

        output = model(data)
        loss = criterion(output, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss / len(train_loader)              

    with torch.no_grad():
        epoch_val_loss = 0
        for data, label in valid_loader:
            data = data.to(device)
            label = label.to(device)

            val_output = model(data)
            val_loss = criterion(val_output, label)

            epoch_val_loss += val_loss / len(valid_loader)

    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - val_loss : {epoch_val_loss:.4f}\n"
    )

    train_loss_list.append(epoch_loss.cpu().detach().numpy())
    val_loss_list.append(epoch_val_loss.cpu().detach().numpy())

    print()

スクリーンショット 2023-07-12 1.55.51.png
このような感じで学習が進んでいきます。

学習結果の可視化

plt.plot(train_loss_list, label='train')
plt.plot(val_loss_list, label='valid')
plt.legend()
plt.show()

スクリーンショット 2023-07-12 2.03.12.png

テスト

実際に構築したモデルを使ってテストデータに対して年齢予測を行います。

test_loss=[]
running_test_loss = 0.0
pred=[]
ans=[]
with torch.set_grad_enabled(False):
    for data in testloader:
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)
        list1 = labels.tolist()
        outputs = model(inputs)
        list2 = outputs.tolist()
        for i in range(len(list1)):
            ans.append(list1[i])
        for i in range(len(list2)):
            pred.append(list2[i])
        loss = criterion(outputs, labels)
        running_test_loss += loss.item()

test_loss.append(running_test_loss / len(testloader))

print('test loss: {}'.format(running_test_loss / len(testloader)))

test loss: 5.685766831040382

#予測値と正解値の取得
import collections


def flatten(l):   #リストの1次元化
    for el in l:
        if isinstance(el, collections.abc.Iterable) and not isinstance(el, (str, bytes)):
            yield from flatten(el)
        else:
            yield el

ans= list(flatten(ans))
pred= list(flatten(pred))
x = []
y = []
file = open('./age1.csv', 'r')
data = csv.reader(file)
for row in data:
    x.append(row[0])
    y.append(float(row[1]))
file.close()

imagelist=[]
labellist=[]
for i in range(40):
    imagelist.append(PIL.Image.open(x[9500+i]))
    labellist.append(ans)

予測結果表示

import numpy as np
from PIL import Image
import matplotlib.pyplot  as plt
import japanize_matplotlib
fig = plt.figure(figsize=(10,6))
for i, im in enumerate(imagelist):
    
fig.add_subplot(4,10,i+1).set_title('{}\n{}'.format(int(pred[i]),int(ans[i])))
    plt.axis('off')
    plt.imshow(im)
plt.show()

上段に予測値、下段に正解値を並べています。
スクリーンショット 2023-07-12 2.17.20.png

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?