Let's implement class recognition for hand-drawn sketch images in PyTorch.
For ordinary photographs, the quickest approach would be to grab a pretrained ResNet or similar and fine-tune it, but sketches have quite different features from regular photos, so we will build a different model.
Model to implement
https://github.com/HPrinz/sketch-recognition
This repository is written in Keras; here we reimplement its Sketch-A-Net CNN in PyTorch.
Two other models are described there as well, but this one appears to give the best accuracy on the dataset we will use.
Dataset
I downloaded the Sketch dataset [png, ~500MB] from:
http://cybertron.cg.tu-berlin.de/eitz/projects/classifysketch/
Of these, we take the following 10 classes:
butterfly, chair, dog, dragon, elephant, horse, pizza, race_car, ship, toilet
and split each class into
- train: 80 images per class
- test: 10 images per class
laid out in the folder structure below (a helper script for the split is sketched after the tree):
```
root/
├ train/
│ ├ horse/
│ │ ├ 8537.png
│ │ └ ...
│ ├ butterfly/
│ │ ├ 2857.png
│ └ ...
├ test/
│ ├ horse/
│ │ ├ 8536.png
│ │ └ ...
│ ├ butterfly/
│ │ ├ 2856.png
│ └ ...
```
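The download unpacks into one subfolder per category. A minimal split script, assuming the unpacked PNGs sit under `./png/<class>/` (the source folder name, random seed, and 80/10 counts are assumptions; adjust to your layout):

```python
import random
import shutil
from pathlib import Path

CLASSES = ["butterfly", "chair", "dog", "dragon", "elephant",
           "horse", "pizza", "race_car", "ship", "toilet"]

def split_dataset(src_root="./png", dst_root=".", n_train=80, n_test=10, seed=0):
    """Copy a random 80/10 split of each class into train/ and test/."""
    random.seed(seed)
    for cls in CLASSES:
        files = sorted(Path(src_root, cls).glob("*.png"))
        random.shuffle(files)
        for subset, chunk in [("train", files[:n_train]),
                              ("test", files[n_train:n_train + n_test])]:
            out_dir = Path(dst_root, subset, cls)
            out_dir.mkdir(parents=True, exist_ok=True)
            for f in chunk:
                shutil.copy(f, out_dir / f.name)

split_dataset()
```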
Creating the dataloader
Pointing `torchvision.datasets.ImageFolder` at a root directory whose images are stored in one subfolder per class is enough for it to load everything automatically.
```python
import os
import numpy as np  # used later for tracking losses
import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision
import torchvision.transforms as transforms

"""
Following https://github.com/HPrinz/sketch-recognition,
the input size is (225, 225)
"""
data_transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((225, 225)),  # (h, w) tuple so both sides become 225
    transforms.ToTensor(),
])

train_data = torchvision.datasets.ImageFolder(root='./train', transform=data_transform)
train_data_loader = torch.utils.data.DataLoader(train_data,
                                                batch_size=4,
                                                shuffle=True,
                                                num_workers=4)
test_data = torchvision.datasets.ImageFolder(root='./test', transform=data_transform)
test_data_loader = torch.utils.data.DataLoader(test_data,
                                               batch_size=4,
                                               shuffle=False,  # no shuffling needed for evaluation
                                               num_workers=4)
```
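`ImageFolder` assigns label indices alphabetically by folder name, so it is worth printing the mapping once and checking that a batch has the expected grayscale 225x225 shape:

```python
print(train_data.class_to_idx)  # e.g. {'butterfly': 0, 'chair': 1, ...}
imgs, labels = next(iter(train_data_loader))
print(imgs.shape)    # torch.Size([4, 1, 225, 225])
print(labels.shape)  # torch.Size([4])
```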
Building the model
We implement the model according to the following table:
| # | Layer | Filter Size | Filter Num | Stride | Padding | Output |
|---|---|---|---|---|---|---|
| 0 | Input | | | | | 225x225 |
| 1 | Conv (ReLU) | 15x15 | 64 | 3 | 0 | 71x71 |
| | MaxPool | 3x3 | | 2 | 0 | 35x35 |
| 2 | Conv (ReLU) | 3x3 | 128 | 1 | 0 | 31x31 |
| | MaxPool | 3x3 | | 2 | 0 | 15x15 |
| 3 | Conv (ReLU) | 3x3 | 256 | 1 | 1 | 15x15 |
| 4 | Conv (ReLU) | 3x3 | 256 | 1 | 1 | 15x15 |
| 5 | Conv (ReLU) | 3x3 | 256 | 1 | 1 | 15x15 |
| | MaxPool | 3x3 | | 2 | 0 | 7x7 |
| 6 | Conv (ReLU) | 7x7 | 512 | 1 | 0 | 1x1 |
| | Dropout 0.5 | | | | | 1x1 |
| 7 | Conv (ReLU) | 1x1 | 512 | 1 | 0 | 1x1 |
| | Dropout 0.5 | | | | | 1x1 |
| 8 | Conv (ReLU) | 1x1 | 250 | 1 | 0 | 1x1 |
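Each output size follows the usual convolution arithmetic, out = floor((in + 2*padding - kernel) / stride) + 1. A small helper to verify the table (note that layer 2 as coded, a 3x3 stride-1 conv with no padding, gives 33x33 rather than the listed 31x31, which would match a 5x5 filter; either way the feature map reaches 7x7 before layer 6):

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Standard conv/pool output-size formula."""
    return (size + 2 * padding - kernel) // stride + 1

s = conv_out(225, 15, 3)   # conv1 -> 71
s = conv_out(s, 3, 2)      # pool  -> 35
s = conv_out(s, 3, 1)      # conv2 -> 33 (table lists 31)
s = conv_out(s, 3, 2)      # pool  -> 16
s = conv_out(s, 3, 1, 1)   # conv3-5 keep the size
s = conv_out(s, 3, 2)      # pool  -> 7
s = conv_out(s, 7, 1)      # conv6 -> 1
print(s)  # 1
```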
```python
class SketchModel(nn.Module):
    """
    https://github.com/HPrinz/sketch-recognition
    input size: (225, 225)
    """
    def __init__(self):
        super(SketchModel, self).__init__()
        self.output_num = 10
        self.filter_num = 64
        self.act = nn.LeakyReLU(0.2)
        self.conv1 = self.conv_block(1, self.filter_num*1, kernel_size=15, stride=3, padding=0, act_fn=self.act)
        self.conv2 = self.conv_block(self.filter_num*1, self.filter_num*2, kernel_size=3, stride=1, padding=0, act_fn=self.act)
        self.conv3 = self.conv_block(self.filter_num*2, self.filter_num*4, kernel_size=3, stride=1, padding=1, act_fn=self.act)
        self.conv4 = self.conv_block(self.filter_num*4, self.filter_num*4, kernel_size=3, stride=1, padding=1, act_fn=self.act)
        self.conv5 = self.conv_block(self.filter_num*4, self.filter_num*4, kernel_size=3, stride=1, padding=1, act_fn=self.act)
        self.conv6 = self.conv_block(self.filter_num*4, self.filter_num*8, kernel_size=7, stride=1, padding=0, act_fn=self.act)
        self.conv7 = self.conv_block(self.filter_num*8, self.filter_num*8, kernel_size=1, stride=1, padding=0, act_fn=self.act)
        # narrow down to 50 channels here, then map to the 10 classes below
        self.conv8 = self.conv_block(self.filter_num*8, 50, kernel_size=1, stride=1, padding=0, act_fn=self.act)
        self.out = nn.Sequential(
            nn.Linear(50, self.output_num),
            nn.Sigmoid(),
        )
        self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)
        self.drop = nn.Dropout2d(p=0.25)

    def forward(self, input):
        h = self.pool(self.conv1(input))
        h = self.pool(self.conv2(h))
        h = self.pool(self.conv5(self.conv4(self.conv3(h))))
        h = self.drop(self.conv6(h))
        h = self.conv8(self.drop(self.conv7(h)))
        h = self.out(self.flatten(h))
        return h

    def conv_block(self, in_dim, out_dim, kernel_size, stride, padding, act_fn):
        return nn.Sequential(
            nn.Conv2d(in_dim, out_dim, kernel_size, stride, padding),
            act_fn,
        )

    def flatten(self, x):
        bs = x.size()[0]
        return x.view(bs, -1)
```
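Before training, a quick forward pass with a dummy batch confirms the shapes line up:

```python
model = SketchModel()
dummy = torch.zeros(4, 1, 225, 225)  # batch of 4 grayscale 225x225 images
print(model(dummy).shape)            # torch.Size([4, 10])
```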
Running training
```python
model = SketchModel().cuda()  # drop .cuda() when running on CPU
lr = 1e-4
optim = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss().cuda()  # drop .cuda() when running on CPU

def train(epoch, name):
    os.makedirs('models/' + name, exist_ok=True)
    train_loss = np.array([])
    test_loss = np.array([])
    for i in range(epoch):
        loss_per_epoch = 0
        acc = 0
        model.train()
        for batch_idx, (imgs, labels) in enumerate(train_data_loader):
            optim.zero_grad()
            imgs = imgs.cuda().float()  # drop .cuda() when running on CPU
            labels = labels.cuda().long()
            estimated = model(imgs)
            loss = criterion(estimated, labels)
            loss.backward()
            optim.step()
            loss_per_epoch += loss.item()
            acc += torch.sum(labels == torch.argmax(estimated, dim=1)).cpu().numpy()
        train_loss = np.append(train_loss, loss_per_epoch)
        print("epoch: {}, train_loss: {}".format(i, train_loss[-1]))
        print("train_acc: {}".format(acc / len(train_data)))

        loss_per_epoch = 0
        acc = 0
        model.eval()
        with torch.no_grad():  # no gradients needed for evaluation
            for batch_idx, (imgs, labels) in enumerate(test_data_loader):
                imgs = imgs.cuda().float()  # drop .cuda() when running on CPU
                labels = labels.cuda().long()
                estimated = model(imgs)
                loss = criterion(estimated, labels)
                # rescale so train and test losses are comparable per epoch
                loss_per_epoch += loss.item() * len(train_data) / len(test_data)
                acc += torch.sum(labels == torch.argmax(estimated, dim=1)).cpu().numpy()
        test_loss = np.append(test_loss, loss_per_epoch)
        print("epoch: {}, test_loss: {}".format(i, test_loss[-1]))
        print("test_acc: {}".format(acc / len(test_data)))

train(40, 'SketchModel')
```
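`train()` creates a `models/<name>` directory but the saving step is not shown; one way to checkpoint, assuming per-epoch files (the naming scheme is hypothetical), would be to add inside the epoch loop:

```python
# inside the epoch loop of train(), after evaluation (hypothetical file name)
torch.save(model.state_dict(), 'models/{}/epoch_{:03d}.pth'.format(name, i))

# and to restore later:
model.load_state_dict(torch.load('models/SketchModel/epoch_039.pth'))
```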
Redefining the model
After 40 epochs with the model above, the test accuracy `test_acc` only reached about 0.65.
So I changed the model to
- add BatchNormalization to each conv block
- insert a partial ResNet-style structure (a residual block with a shortcut connection)
and `test_acc` improved to 0.87.
I would like to reach 0.9, but this is good enough for now.
Below are the redefined model and its training.
```python
class SketchResModel(nn.Module):
    """
    https://github.com/HPrinz/sketch-recognition
    input size: (225, 225)
    """
    def __init__(self):
        super(SketchResModel, self).__init__()
        self.output_num = 10
        self.filter_num = 64
        self.act = nn.LeakyReLU(0.2)
        self.conv1 = self.conv_block(1, self.filter_num*1, kernel_size=15, stride=3, padding=0, act_fn=self.act)
        self.conv2 = self.conv_block(self.filter_num*1, self.filter_num*2, kernel_size=3, stride=1, padding=0, act_fn=self.act)
        self.conv6 = self.conv_block(self.filter_num*4, self.filter_num*8, kernel_size=7, stride=1, padding=0, act_fn=self.act)
        self.conv7 = self.conv_block(self.filter_num*8, self.filter_num*8, kernel_size=1, stride=1, padding=0, act_fn=self.act)
        self.conv8 = self.conv_block(self.filter_num*8, 50, kernel_size=1, stride=1, padding=0, act_fn=self.act)
        self.out = nn.Sequential(
            nn.Linear(50, self.output_num),
            nn.Sigmoid(),
        )
        self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)
        self.drop = nn.Dropout2d(p=0.25)
        # the conv3-conv5 stack of SketchModel is replaced by this residual block
        self.res_block = nn.Sequential(
            self.conv_block(self.filter_num*2, self.filter_num*4, kernel_size=3, stride=1, padding=1, act_fn=self.act),
            self.conv_block(self.filter_num*4, self.filter_num*4, kernel_size=3, stride=1, padding=1, act_fn=self.act),
            nn.Conv2d(self.filter_num*4, self.filter_num*4, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(self.filter_num*4),
        )
        # shortcut path: a conv that matches the channel count of res_block
        self.residual = nn.Sequential(
            nn.Conv2d(self.filter_num*2, self.filter_num*4, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(self.filter_num*4),
        )

    def forward(self, input):
        h = self.pool(self.conv1(input))
        h = self.pool(self.conv2(h))
        h1 = self.residual(h)
        h = self.pool(self.act(self.res_block(h) + h1))  # add the shortcut, then activate
        h = self.drop(self.conv6(h))
        h = self.conv8(self.drop(self.conv7(h)))
        h = self.out(self.flatten(h))
        return h

    def conv_block(self, in_dim, out_dim, kernel_size, stride, padding, act_fn):
        return nn.Sequential(
            nn.Conv2d(in_dim, out_dim, kernel_size, stride, padding),
            nn.BatchNorm2d(out_dim),
            act_fn,
        )

    def flatten(self, x):
        bs = x.size()[0]
        return x.view(bs, -1)
```
```python
# add horizontal flips as data augmentation for training
data_transform2 = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((225, 225)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

train_data = torchvision.datasets.ImageFolder(root='./train', transform=data_transform2)
train_data_loader = torch.utils.data.DataLoader(train_data,
                                                batch_size=4,
                                                shuffle=True,
                                                num_workers=4)
# keep the deterministic transform (no flip) for evaluation
test_data = torchvision.datasets.ImageFolder(root='./test', transform=data_transform)
test_data_loader = torch.utils.data.DataLoader(test_data,
                                               batch_size=4,
                                               shuffle=False,
                                               num_workers=4)
```
```python
model = SketchResModel().cuda()  # drop .cuda() when running on CPU
lr = 1e-4
optim = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss().cuda()  # drop .cuda() when running on CPU
train(40, 'SketchResModel')
```
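To actually use the trained model, the predicted index maps back to a class name via the dataset's `classes` attribute; a minimal sketch:

```python
model.eval()
with torch.no_grad():
    imgs, labels = next(iter(test_data_loader))
    preds = torch.argmax(model(imgs.cuda().float()), dim=1)  # drop .cuda() on CPU
    for p, t in zip(preds.cpu(), labels):
        print("predicted: {}, actual: {}".format(
            test_data.classes[int(p)], test_data.classes[int(t)]))
```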