LoginSignup
1
0

More than 1 year has passed since last update.

Python+画像認識AIでデジカメ写真を採点したい

Last updated at Posted at 2022-05-29

1.目的

以前の記事で、これまでに撮り貯めたデジカメ写真のExifを一覧化した。ここから、「何年のでもいいから5月28日の写真」とか、「このレンズ+このカメラで撮った写真」といった条件下での、いい写真を探し出したい。
このため画像認識AIを用いて、写真の採点をするAIを学習したい。前述のExif情報を貯め込んだMySQLデータベースから特定の条件の写真群を抽出し、これをAIで採点し、ベストな写真を選択したい。

2.環境

Windows 10 64bit
Anaconda (Python 3.6.10, 64bit)

  • torch 1.9.0+cu111
  • torchvision 0.10.0+cu111

NVIDIA GeForece RTX 2070 (CUDA 11.1 + CUDNN 8.1.1)

3.AIの学習

まずまず良いと思う写真を3000枚ほど集めた。苦行だわ。そんなにないわ。
これを「1_data_train/1_good」というフォルダに格納した。

次に、ゴミのような写真を1000枚ほど集めてきた。掃いて捨てるほどあるわ。
さらに、ゴミって程ではないが、良いとも言えない写真を2000枚ほど。
これらを「1_data_train/2_notgood」に格納した。

各写真をどちらのクラスに分類するかの線引きが難しいが、その曖昧さが最終的なAI採点において、0点でも100点でもない中間スコアを醸し出すのだと思う。迷うような写真は50点近辺でいいから、どちらに分類してもいいのだ。だから学習においてLossが思ったほど下がらなくても、それでいいのだと思う。

TorchvisionのImageNet学習済みAIをベースに使ったので、画像サイズは短辺224ピクセルで十分。だが、後々何か細工をするかもしれないと思ったので、今回は長辺600ピクセルのデータを準備した。なお、これは以下のコードには何ら影響しない。

optimizerはSAMを使った。コードは以下のgitより使わせていただいた。

https://github.com/davda54/sam

学習時のコードは以下。上記のsam.pyを同じフォルダに置く。
Torchvisionの学習済みResNet-50をベースに、最後のFC層を2分類に付け替えて、全レイヤーを再学習する。

import os
import glob
from PIL import Image
from tqdm import tqdm
import torch
import torch.nn as nn
from torchvision import models, transforms

from sam import SAM

def main():
    lr = 0.0005
    momentum = 0.9

    num_epochs = 200
    batch_size = 64

    label_names = ['1_good', '2_notgood']
    num_classes = len(label_names)

    early_stop = 20

    model_name = 'rn50_photo1.pth'

    data_dir = '1_data_train'
    log_dir = 'torch_logs'
    
    # ResNet-50の最終FC層を、2分類のFCに付け替える
    net = models.resnet50(pretrained=True)
    num_features = net.fc.in_features
    net.fc = nn.Linear(num_features, num_classes)

    # model_nameのファイルが既に存在すれば、そこから学習再開
    continue_flag = 0
    if os.path.exists(os.path.join(log_dir, model_name)):
        net.load_state_dict(torch.load(os.path.join(log_dir, model_name)))
        continue_flag = 1

    # GPUが使えれば使う
    if torch.cuda.is_available():
        device = 'cuda'
        print('GPU is available')
    else:
        device = 'cpu'

    net = net.to(device)

    # データセットの準備
    train_dataset, valid_dataset = prepare_dataset(data_dir)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

    # 最適化手法と最適化指標
    criterion = nn.CrossEntropyLoss()
    optimizer = SAM(net.parameters(), torch.optim.SGD, lr=lr, momentum=momentum)

    # log準備 追記モード
    os.makedirs(log_dir, exist_ok=True)
    log_file = os.path.join(log_dir, model_name.replace('.pth','.csv'))
    with open(log_file, mode='a') as f:
        print('epoch,loss,val_loss,val_acc', file=f)

    # 学習ループ
    not_improved_count = 0
    leaest_loss = 100
    for epoch in range(num_epochs):
        if epoch == 0 and continue_flag:
            # 学習再開時の0エポック目は学習しない
            loss = 1.0
        else:
            loss = train(net, criterion, optimizer, train_loader, device)
        val_loss, val_acc = valid(net, criterion, valid_loader, device)

        if val_loss < leaest_loss:
            # val_lossが改善した時だけ重みを保存する
            leaest_loss = val_loss
            torch.save(net.state_dict(), os.path.join(log_dir, model_name))
            not_improved_count = 0
            saved = ',saved'
        else:
            not_improved_count += 1
            saved = ''
        
        print('%3d/%3d\tloss:%.3f / val_loss:%.3f / val_acc:%.3f%s' % (epoch, num_epochs, loss, val_loss, val_acc, saved))
        
        with open(log_file, mode='a') as f:
            print('%d,%.4f,%.4f,%.4f%s' % (epoch, loss, val_loss, val_acc, saved), file=f)

        if early_stop and not_improved_count > early_stop:
            # early_stop回以上、val_lossの改善がなかったら終了
            break
    return

####
def train(model, criterion, optimizer, train_loader, device):
    model.train()
    running_loss = 0.0
    for images, labels in tqdm(train_loader):
        images = images.to(device)
        labels = labels.to(device)
        
        # first forward-backward step
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.first_step(zero_grad=True)
        running_loss += loss.item()

        # second forward-backward step
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.second_step(zero_grad=True)
    train_loss = running_loss / len(train_loader)
    return train_loss

####
def valid(model, criterion, valid_loader, device):
    model.eval()
    running_loss = 0.0
    correct = 0.0
    total = 0.0
    with torch.no_grad():
        for images, labels in tqdm(valid_loader):
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)

            loss = criterion(outputs, labels)
            running_loss += loss.item()

            _, predicted = torch.max(outputs.data, 1)
            correct += torch.sum(predicted == labels.data)
            total += labels.size(0)
    val_loss = running_loss / len(valid_loader)
    val_acc = correct / total
    return val_loss, val_acc

####
def train_preprocess(img):
    train_transform = [
        transforms.ColorJitter(brightness=(0.7,1.3), contrast=(0.7,1.3), saturation=(0.7,1.3), hue=(-0.2,0.2)),
        transforms.RandomAffine(20, translate=(0.1,0.1), scale=(1.0,1.1), shear=(-0.1,0.1)),
        transforms.Resize(224),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]
    return transforms.Compose(train_transform)(img)

####
def test_preprocess(img):
    test_transform = [
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]
    return transforms.Compose(test_transform)(img)

####
class MyDataset(torch.utils.data.Dataset):
    def __init__(self, file_list, transform=None):
        self.file_list = file_list
        self.transform = transform
    def __len__(self):
        return len(self.file_list)
    def __getitem__(self,index):
        img_path = self.file_list[index]
        img = Image.open(img_path).convert('RGB')
        img_transformed = self.transform(img)
        dirname = img_path.split('\\')[-2]
        if dirname == '1_good':
            # フォルダ名を決め打ちしているので注意
            label = 0
        else:
            label = 1
        return img_transformed, label

####
def prepare_dataset(data_dir):
    image_file_list = glob.glob(data_dir+'/**/*.jpg', recursive=True)
    image_file_list.extend(glob.glob(data_dir+'/**/*.jpeg', recursive=True))

    print('N =', len(image_file_list))

    # データセットをtrainとvalidationに分割
    train_ratio = 0.75
    train_size = int(train_ratio * len(image_file_list))
    val_size = len(image_file_list) - train_size
    data_train, data_val = torch.utils.data.random_split(image_file_list, [train_size, val_size])

    train_dataset = MyDataset(file_list=data_train, transform=train_preprocess)
    valid_dataset = MyDataset(file_list=data_val, transform=test_preprocess)
    return train_dataset, valid_dataset

####
if __name__ ==  '__main__' :
    main()

GPUで1時間ほど学習ののち、ちょっと早めだが打ち切った。
image.png

4.AIの推論

テストフォルダ内の写真(2_test_data/*.jpg)について、AI判定する。

import os
import glob
from PIL import Image

import torch
import torch.nn as nn
from torchvision import models, transforms

def main():
    weights = 'torch_logs/rn50_photo1.pth'
    test_dir = '2_test_data'
    label_names = ['1_good', '2_notgood']

    files = glob.glob(os.path.join(test_dir, '*.jpg'))

    # NN準備
    net = models.resnet50(pretrained=False)
    num_features = net.fc.in_features
    net.fc = nn.Linear(num_features, len(label_names))

    # 学習したモデルのロード
    net.load_state_dict(torch.load(weights))
    net.eval()

    # GPUが使えれば使う
    use_gpu = torch.cuda.is_available()
    if use_gpu:
        print('GPU is available')
        net = net.cuda()

    print('N =', len(files))
    for filename in files:
        img = Image.open(filename).convert('RGB')
        img_prerprocessed = test_preprocess(img)
        img_prerprocessed = img_prerprocessed.unsqueeze_(0)
        if use_gpu:
            img_prerprocessed = img_prerprocessed.cuda()
        outputs = net(img_prerprocessed).cpu()
        softmax = torch.nn.functional.softmax(outputs.detach(), dim=1)[0].numpy()

        # ソフトマックス後のclass 0の値を出力
        print(filename, '\t%.2f'%(softmax[0] * 100))
    return

test_preprocess = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

####
if __name__ ==  '__main__' :
    main()

学習には用いていない適当なデータでテストする。0から始まるのはゴミまたはいまいちな写真、1から始まるのはそこそこ良いと思う写真。

2_test_data\0_DSC00001.JPG      0.49
2_test_data\0_DSC00011.JPG      36.32
2_test_data\0_DSC00021.JPG      2.33
2_test_data\0_DSC00025.JPG      19.52
2_test_data\1_DSCF4362.jpg      79.73
2_test_data\1_DSCF4376.jpg      83.37
2_test_data\1_DSCF6687.jpg      95.81
2_test_data\1_DSCF6698.jpg      88.38
2_test_data\1_DSCF6714.jpg      93.66

まずまず、いい感じに採点してくれるんじゃなかろうか。

5.条件にマッチする中から、いい写真を選ぶ

以下のように、MySQLに写真の撮影情報が保管されている、という前提。ここからsqlで絞り込んだ写真をAIで採点し、スコアの高い写真を選ぶこととする。
image.png

余談だが、上のようにExcel(32bit)からODBCで接続するには32bit版のMySQL ODBCドライバが必要。下のようにPython(64bit)のpyodbcからODBCで接続するには64bit版のMySQL ODBCドライバが必要だった。両方入れた。

import os
from PIL import Image
from tqdm import tqdm
import pandas as pd
import torch
import torch.nn as nn
from torchvision import models, transforms
import pyodbc

####
def connect_db():
    user = '****'
    password = '****'
    database = '****'
    con = pyodbc.connect("DRIVER={MySQL ODBC 8.0 Unicode Driver};SERVER=localhost;" +
                            "UID=%s;PWD=%s;DATABASE=%s"%(user, password, database))
    return con
####
def select_execute(con, sql):
    cursor = con.cursor()
    cursor.execute(sql)
    rows = cursor.fetchall()
    cursor.close()
    return rows

### NN準備
weights = 'torch_logs/rn50_photo1.pth'
net = models.resnet50(pretrained=False)
num_features = net.fc.in_features
net.fc = nn.Linear(num_features, 2)

# 学習したモデルのロード
net.load_state_dict(torch.load(weights))
net.eval()

# GPUが使えれば使う
use_gpu = torch.cuda.is_available()
if use_gpu:
    print('GPU is available')
    net = net.cuda()

# preprocess
test_preprocess = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
print('prepared')

#%% DBからSelect
# Lens名で検索
# sql ='SELECT Filepath FROM T_Photo_Exif WHERE LensModel like "XF56%"'
# 〇月×日で検索
sql ='SELECT Filepath FROM T_Photo_Exif \
      WHERE MONTH(Datetime)=5 and DAYOFMONTH(Datetime)=28'
con = connect_db()
res = select_execute(con, sql)
con.close()
files = [r[0] for r in res]
print('N =', len(files))
print('fetched from db')

#%% AI推論
df = pd.DataFrame(columns=['Score'])
for filename in tqdm(files):
    img = Image.open(filename).convert('RGB')
    img_prerprocessed = test_preprocess(img)
    img_prerprocessed = img_prerprocessed.unsqueeze_(0)
    if use_gpu:
        img_prerprocessed = img_prerprocessed.cuda()
    outputs = net(img_prerprocessed).cpu()
    softmax = torch.nn.functional.softmax(outputs.detach(), dim=1)[0].numpy()
    df.at[filename, 'Score'] = softmax[0] * 100
print('inference finished')

#%%
# Scoreで降順ソート
df = df.sort_values('Score', ascending=False)
# Top 10を表示
print(df.iloc[:10])

気になるTop 10は?

              Score
DSCF5044.JPG  97.4728
DSCF5045.JPG  97.1702
DSCF5043.JPG  96.6063
DSCF5047.JPG   96.563
DSCF4946.JPG  94.4338
DSCF4937.JPG  93.4808
DSCF5056.JPG  93.0035
DSCF5046.JPG  92.8627
DSCF4953.JPG  92.4802
DSCF4940.JPG  92.3308

AI氏のおすすめはこれ↓のようだ。奈良だなあ。悪くはないけどこれが97点か。
1~3位は連番で、どれもほぼ同じような写真だった。
image.png

ちなみにこの条件下で、私が自分で選んだ写真↓は60点だった。かろうじて合格。
image.png

気になる一番ゴミな写真は?・・・うん、確かに、これは2点だわ。
image.png

全247枚の点数分布をみるとこんな感じ。全体的には点数は甘めかな。
image.png

この点数をDBに書き込むことで、同じ写真については次回以降はAI判定が不要になる、みたいな仕組みも簡単にできると思うが、今回は省略。

6.まとめなど

  • 自分でデジカメで撮った写真を「良い写真」、「良くない写真」に主観で分けて、それを画像認識AIで学習させた。
  • DBに保管したExif情報を用いて、特定の条件で写真を絞り込んで、その条件下でのベストと思われる写真をAIに提案させた。
  • AI採点の精度はまあまあ。ゴミ写真は正しく判定できそうだが、ベストな写真についてはAIと主観が一致するかは、学習方法や学習データ次第だろう。
  • なお、224×224ピクセルに縮小してからAI判定しているので、この解像度ではぶれている、ボケているなどが正確には分からないだろうとは思う。PatchGANのDみたいに、画像をN × Nに分割して判定して、それを合成したものを最終判定値とするとかもありかもしれない。
  • 学習時、画像を変形したり、色を変えたりといった一般的なData Augmentationを行っているが、「良い写真」というのはDA後も「良い写真」なのだろうか?

7.自分の写真を採点してみた

https://twitter.com/extra_heritage
に上げた写真を採点。絶対値はともかく、相対評価としては、8割がた納得いくかな。
80点未満の写真は、苦し紛れに投稿しているのが見透かされているわ。
image.png

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0