1.目的
以前の記事で、これまでに撮り貯めたデジカメ写真のExifを一覧化した。ここから、「何年のでもいいから5月28日の写真」とか、「このレンズ+このカメラで撮った写真」といった条件下での、いい写真を探し出したい。
このため画像認識AIを用いて、写真の採点をするAIを学習したい。前述のExif情報を貯め込んだMySQLデータベースから特定の条件の写真群を抽出し、これをAIで採点し、ベストな写真を選択したい。
2.環境
Windows 10 64bit
Anaconda (Python 3.6.10, 64bit)
- torch 1.9.0+cu111
- torchvision 0.10.0+cu111
NVIDIA GeForece RTX 2070 (CUDA 11.1 + CUDNN 8.1.1)
3.AIの学習
まずまず良いと思う写真を3000枚ほど集めた。苦行だわ。そんなにないわ。
これを「1_data_train/1_good」というフォルダに格納した。
次に、ゴミのような写真を1000枚ほど集めてきた。掃いて捨てるほどあるわ。
さらに、ゴミって程ではないが、良いとも言えない写真を2000枚ほど。
これらを「1_data_train/2_notgood」に格納した。
各写真をどちらのクラスに分類するかの線引きが難しいが、その曖昧さが最終的なAI採点において、0点でも100点でもない中間スコアを醸し出すのだと思う。迷うような写真は50点近辺でいいから、どちらに分類してもいいのだ。だから学習においてLossが思ったほど下がらなくても、それでいいのだと思う。
TorchvisionのImageNet学習済みAIをベースに使ったので、画像サイズは短辺224ピクセルで十分。だが、後々何か細工をするかもしれないと思ったので、今回は長辺600ピクセルのデータを準備した。なお、これは以下のコードには何ら影響しない。
optimizerはSAMを使った。コードは以下のgitより使わせていただいた。
https://github.com/davda54/sam
学習時のコードは以下。上記のsam.pyを同じフォルダに置く。
Torchvisionの学習済みResNet-50をベースに、最後のFC層を2分類に付け替えて、全レイヤーを再学習する。
import os
import glob
from PIL import Image
from tqdm import tqdm
import torch
import torch.nn as nn
from torchvision import models, transforms
from sam import SAM
def main():
lr = 0.0005
momentum = 0.9
num_epochs = 200
batch_size = 64
label_names = ['1_good', '2_notgood']
num_classes = len(label_names)
early_stop = 20
model_name = 'rn50_photo1.pth'
data_dir = '1_data_train'
log_dir = 'torch_logs'
# ResNet-50の最終FC層を、2分類のFCに付け替える
net = models.resnet50(pretrained=True)
num_features = net.fc.in_features
net.fc = nn.Linear(num_features, num_classes)
# model_nameのファイルが既に存在すれば、そこから学習再開
continue_flag = 0
if os.path.exists(os.path.join(log_dir, model_name)):
net.load_state_dict(torch.load(os.path.join(log_dir, model_name)))
continue_flag = 1
# GPUが使えれば使う
if torch.cuda.is_available():
device = 'cuda'
print('GPU is available')
else:
device = 'cpu'
net = net.to(device)
# データセットの準備
train_dataset, valid_dataset = prepare_dataset(data_dir)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
# 最適化手法と最適化指標
criterion = nn.CrossEntropyLoss()
optimizer = SAM(net.parameters(), torch.optim.SGD, lr=lr, momentum=momentum)
# log準備 追記モード
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, model_name.replace('.pth','.csv'))
with open(log_file, mode='a') as f:
print('epoch,loss,val_loss,val_acc', file=f)
# 学習ループ
not_improved_count = 0
leaest_loss = 100
for epoch in range(num_epochs):
if epoch == 0 and continue_flag:
# 学習再開時の0エポック目は学習しない
loss = 1.0
else:
loss = train(net, criterion, optimizer, train_loader, device)
val_loss, val_acc = valid(net, criterion, valid_loader, device)
if val_loss < leaest_loss:
# val_lossが改善した時だけ重みを保存する
leaest_loss = val_loss
torch.save(net.state_dict(), os.path.join(log_dir, model_name))
not_improved_count = 0
saved = ',saved'
else:
not_improved_count += 1
saved = ''
print('%3d/%3d\tloss:%.3f / val_loss:%.3f / val_acc:%.3f%s' % (epoch, num_epochs, loss, val_loss, val_acc, saved))
with open(log_file, mode='a') as f:
print('%d,%.4f,%.4f,%.4f%s' % (epoch, loss, val_loss, val_acc, saved), file=f)
if early_stop and not_improved_count > early_stop:
# early_stop回以上、val_lossの改善がなかったら終了
break
return
####
def train(model, criterion, optimizer, train_loader, device):
model.train()
running_loss = 0.0
for images, labels in tqdm(train_loader):
images = images.to(device)
labels = labels.to(device)
# first forward-backward step
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.first_step(zero_grad=True)
running_loss += loss.item()
# second forward-backward step
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.second_step(zero_grad=True)
train_loss = running_loss / len(train_loader)
return train_loss
####
def valid(model, criterion, valid_loader, device):
model.eval()
running_loss = 0.0
correct = 0.0
total = 0.0
with torch.no_grad():
for images, labels in tqdm(valid_loader):
images = images.to(device)
labels = labels.to(device)
outputs = model(images)
loss = criterion(outputs, labels)
running_loss += loss.item()
_, predicted = torch.max(outputs.data, 1)
correct += torch.sum(predicted == labels.data)
total += labels.size(0)
val_loss = running_loss / len(valid_loader)
val_acc = correct / total
return val_loss, val_acc
####
def train_preprocess(img):
train_transform = [
transforms.ColorJitter(brightness=(0.7,1.3), contrast=(0.7,1.3), saturation=(0.7,1.3), hue=(-0.2,0.2)),
transforms.RandomAffine(20, translate=(0.1,0.1), scale=(1.0,1.1), shear=(-0.1,0.1)),
transforms.Resize(224),
transforms.RandomCrop(224),
transforms.RandomHorizontalFlip(p=0.5),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
]
return transforms.Compose(train_transform)(img)
####
def test_preprocess(img):
test_transform = [
transforms.Resize(224),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
]
return transforms.Compose(test_transform)(img)
####
class MyDataset(torch.utils.data.Dataset):
def __init__(self, file_list, transform=None):
self.file_list = file_list
self.transform = transform
def __len__(self):
return len(self.file_list)
def __getitem__(self,index):
img_path = self.file_list[index]
img = Image.open(img_path).convert('RGB')
img_transformed = self.transform(img)
dirname = img_path.split('\\')[-2]
if dirname == '1_good':
# フォルダ名を決め打ちしているので注意
label = 0
else:
label = 1
return img_transformed, label
####
def prepare_dataset(data_dir):
image_file_list = glob.glob(data_dir+'/**/*.jpg', recursive=True)
image_file_list.extend(glob.glob(data_dir+'/**/*.jpeg', recursive=True))
print('N =', len(image_file_list))
# データセットをtrainとvalidationに分割
train_ratio = 0.75
train_size = int(train_ratio * len(image_file_list))
val_size = len(image_file_list) - train_size
data_train, data_val = torch.utils.data.random_split(image_file_list, [train_size, val_size])
train_dataset = MyDataset(file_list=data_train, transform=train_preprocess)
valid_dataset = MyDataset(file_list=data_val, transform=test_preprocess)
return train_dataset, valid_dataset
####
if __name__ == '__main__' :
main()
4.AIの推論
テストフォルダ内の写真(2_test_data/*.jpg)について、AI判定する。
import os
import glob
from PIL import Image
import torch
import torch.nn as nn
from torchvision import models, transforms
def main():
weights = 'torch_logs/rn50_photo1.pth'
test_dir = '2_test_data'
label_names = ['1_good', '2_notgood']
files = glob.glob(os.path.join(test_dir, '*.jpg'))
# NN準備
net = models.resnet50(pretrained=False)
num_features = net.fc.in_features
net.fc = nn.Linear(num_features, len(label_names))
# 学習したモデルのロード
net.load_state_dict(torch.load(weights))
net.eval()
# GPUが使えれば使う
use_gpu = torch.cuda.is_available()
if use_gpu:
print('GPU is available')
net = net.cuda()
print('N =', len(files))
for filename in files:
img = Image.open(filename).convert('RGB')
img_prerprocessed = test_preprocess(img)
img_prerprocessed = img_prerprocessed.unsqueeze_(0)
if use_gpu:
img_prerprocessed = img_prerprocessed.cuda()
outputs = net(img_prerprocessed).cpu()
softmax = torch.nn.functional.softmax(outputs.detach(), dim=1)[0].numpy()
# ソフトマックス後のclass 0の値を出力
print(filename, '\t%.2f'%(softmax[0] * 100))
return
test_preprocess = transforms.Compose([
transforms.Resize(224),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
####
if __name__ == '__main__' :
main()
学習には用いていない適当なデータでテストする。0から始まるのはゴミまたはいまいちな写真、1から始まるのはそこそこ良いと思う写真。
2_test_data\0_DSC00001.JPG 0.49
2_test_data\0_DSC00011.JPG 36.32
2_test_data\0_DSC00021.JPG 2.33
2_test_data\0_DSC00025.JPG 19.52
2_test_data\1_DSCF4362.jpg 79.73
2_test_data\1_DSCF4376.jpg 83.37
2_test_data\1_DSCF6687.jpg 95.81
2_test_data\1_DSCF6698.jpg 88.38
2_test_data\1_DSCF6714.jpg 93.66
まずまず、いい感じに採点してくれるんじゃなかろうか。
5.条件にマッチする中から、いい写真を選ぶ
以下のように、MySQLに写真の撮影情報が保管されている、という前提。ここからsqlで絞り込んだ写真をAIで採点し、スコアの高い写真を選ぶこととする。
余談だが、上のようにExcel(32bit)からODBCで接続するには32bit版のMySQL ODBCドライバが必要。下のようにPython(64bit)のpyodbcからODBCで接続するには64bit版のMySQL ODBCドライバが必要だった。両方入れた。
import os
from PIL import Image
from tqdm import tqdm
import pandas as pd
import torch
import torch.nn as nn
from torchvision import models, transforms
import pyodbc
####
def connect_db():
user = '****'
password = '****'
database = '****'
con = pyodbc.connect("DRIVER={MySQL ODBC 8.0 Unicode Driver};SERVER=localhost;" +
"UID=%s;PWD=%s;DATABASE=%s"%(user, password, database))
return con
####
def select_execute(con, sql):
cursor = con.cursor()
cursor.execute(sql)
rows = cursor.fetchall()
cursor.close()
return rows
### NN準備
weights = 'torch_logs/rn50_photo1.pth'
net = models.resnet50(pretrained=False)
num_features = net.fc.in_features
net.fc = nn.Linear(num_features, 2)
# 学習したモデルのロード
net.load_state_dict(torch.load(weights))
net.eval()
# GPUが使えれば使う
use_gpu = torch.cuda.is_available()
if use_gpu:
print('GPU is available')
net = net.cuda()
# preprocess
test_preprocess = transforms.Compose([
transforms.Resize(224),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
print('prepared')
#%% DBからSelect
# Lens名で検索
# sql ='SELECT Filepath FROM T_Photo_Exif WHERE LensModel like "XF56%"'
# 〇月×日で検索
sql ='SELECT Filepath FROM T_Photo_Exif \
WHERE MONTH(Datetime)=5 and DAYOFMONTH(Datetime)=28'
con = connect_db()
res = select_execute(con, sql)
con.close()
files = [r[0] for r in res]
print('N =', len(files))
print('fetched from db')
#%% AI推論
df = pd.DataFrame(columns=['Score'])
for filename in tqdm(files):
img = Image.open(filename).convert('RGB')
img_prerprocessed = test_preprocess(img)
img_prerprocessed = img_prerprocessed.unsqueeze_(0)
if use_gpu:
img_prerprocessed = img_prerprocessed.cuda()
outputs = net(img_prerprocessed).cpu()
softmax = torch.nn.functional.softmax(outputs.detach(), dim=1)[0].numpy()
df.at[filename, 'Score'] = softmax[0] * 100
print('inference finished')
#%%
# Scoreで降順ソート
df = df.sort_values('Score', ascending=False)
# Top 10を表示
print(df.iloc[:10])
気になるTop 10は?
Score
DSCF5044.JPG 97.4728
DSCF5045.JPG 97.1702
DSCF5043.JPG 96.6063
DSCF5047.JPG 96.563
DSCF4946.JPG 94.4338
DSCF4937.JPG 93.4808
DSCF5056.JPG 93.0035
DSCF5046.JPG 92.8627
DSCF4953.JPG 92.4802
DSCF4940.JPG 92.3308
AI氏のおすすめはこれ↓のようだ。奈良だなあ。悪くはないけどこれが97点か。
1~3位は連番で、どれもほぼ同じような写真だった。
ちなみにこの条件下で、私が自分で選んだ写真↓は60点だった。かろうじて合格。
気になる一番ゴミな写真は?・・・うん、確かに、これは2点だわ。
全247枚の点数分布をみるとこんな感じ。全体的には点数は甘めかな。
この点数をDBに書き込むことで、同じ写真については次回以降はAI判定が不要になる、みたいな仕組みも簡単にできると思うが、今回は省略。
6.まとめなど
- 自分でデジカメで撮った写真を「良い写真」、「良くない写真」に主観で分けて、それを画像認識AIで学習させた。
- DBに保管したExif情報を用いて、特定の条件で写真を絞り込んで、その条件下でのベストと思われる写真をAIに提案させた。
- AI採点の精度はまあまあ。ゴミ写真は正しく判定できそうだが、ベストな写真についてはAIと主観が一致するかは、学習方法や学習データ次第だろう。
- なお、224×224ピクセルに縮小してからAI判定しているので、この解像度ではぶれている、ボケているなどが正確には分からないだろうとは思う。PatchGANのDみたいに、画像をN × Nに分割して判定して、それを合成したものを最終判定値とするとかもありかもしれない。
- 学習時、画像を変形したり、色を変えたりといった一般的なData Augmentationを行っているが、「良い写真」というのはDA後も「良い写真」なのだろうか?
7.自分の写真を採点してみた
https://twitter.com/extra_heritage
に上げた写真を採点。絶対値はともかく、相対評価としては、8割がた納得いくかな。
80点未満の写真は、苦し紛れに投稿しているのが見透かされているわ。