
[PyTorch] Implementing class recognition for hand-drawn sketches in PyTorch

Posted at 2018-12-04

Let's implement class recognition for hand-drawn sketch images in PyTorch.
For ordinary photographs, the quickest approach would be to grab a pretrained ResNet or similar and fine-tune it, but in this case the features are quite different from those of ordinary photos, so we build a different model.

Model to implement

https://github.com/HPrinz/sketch-recognition
This is a model written in Keras; here we reimplement its Sketch-A-Net CNN in PyTorch.
The repository describes two other models as well, but this one appears to give the best accuracy on the dataset we use here.

Dataset

The dataset is the Sketch dataset [png, ~500MB], downloaded from the page below:
http://cybertron.cg.tu-berlin.de/eitz/projects/classifysketch/

Each class contains 80 sketch images.

From these, we take the following 10 classes:
butterfly, chair, dog, dragon, elephant, horse, pizza, race_car, ship, toilet

and split them into

  • train: 80 images per class
  • test: 10 images per class

with the folder structure below (a small script sketch for producing this split follows the tree):

root/
 ├ train/
 │ ├ horse/
 │ │  ├ 8537.png
 │ │  └ ...
 │ ├ butterfly/
 │ │  ├ 2857.png
 │ │  └ ...
 │ └ ...
 ├ test/
 │ ├ horse/
 │ │  ├ 8536.png
 │ │  └ ...
 │ ├ butterfly/
 │ │  ├ 2856.png
 │ │  └ ...
 │ └ ...
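
Producing this split is not shown in the original post; the following is a minimal sketch, assuming the downloaded archive was extracted to ./png with one subfolder per class, and that the 10 test images per class are held out from the downloaded images. The paths, folder names, and random seed are assumptions.

import random
import shutil
from pathlib import Path

SRC = Path('./png')   # extracted sketch dataset: ./png/<class_name>/*.png (assumed location)
DST = Path('.')       # train/ and test/ are created here
CLASSES = ['butterfly', 'chair', 'dog', 'dragon', 'elephant',
           'horse', 'pizza', 'race_car', 'ship', 'toilet']  # folder names in the archive may differ slightly
TEST_PER_CLASS = 10

random.seed(0)
for cls in CLASSES:
    files = sorted((SRC / cls).glob('*.png'))
    random.shuffle(files)
    for split, split_files in [('test', files[:TEST_PER_CLASS]),
                               ('train', files[TEST_PER_CLASS:])]:
        out_dir = DST / split / cls
        out_dir.mkdir(parents=True, exist_ok=True)
        for f in split_files:
            shutil.copy(f, out_dir / f.name)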

Creating the dataloaders

If you point torchvision.datasets.ImageFolder at a root directory in which the images are stored in per-class subfolders, it loads them automatically.


import os
import numpy as np
import torch
import torch.nn as nn
import torch.utils as utils
import torch.nn.init as init
import torch.utils.data as data
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset
import torch.nn.functional as F
import torchvision.models as models
"""
Following https://github.com/HPrinz/sketch-recognition,
the input size is (225, 225).
"""
data_transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize(225),
    transforms.ToTensor(),
])

train_data = torchvision.datasets.ImageFolder(root='./train', transform=data_transform)
train_data_loader = torch.utils.data.DataLoader(train_data,
                                          batch_size=4,
                                          shuffle=True,
                                          num_workers=4)    
test_data = torchvision.datasets.ImageFolder(root='./test', transform=data_transform)
test_data_loader = torch.utils.data.DataLoader(test_data,
                                          batch_size=4,
                                          shuffle=True,
                                          num_workers=4)
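
ImageFolder assigns an integer label to each class in alphabetical order of the subfolder names. A quick sanity check of the mapping and dataset sizes (not part of the original post):

print(train_data.class_to_idx)          # mapping from class folder name to label index
print(len(train_data), len(test_data))  # number of train / test images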

Building the model

We implement it following the architecture below:

|   | Layer       | Filter Size | Filter Num | Stride | Padding | Output  |
|---|-------------|-------------|------------|--------|---------|---------|
| 0 | Input       |             |            |        |         | 225x225 |
| 1 | Conv (ReLU) | 15x15       | 64         | 3      | 0       | 71x71   |
|   | MaxPool     | 3x3         |            | 2      | 0       | 35x35   |
| 2 | Conv (ReLU) | 3x3         | 128        | 1      | 0       | 31x31   |
|   | MaxPool     | 3x3         |            | 2      | 0       | 15x15   |
| 3 | Conv (ReLU) | 3x3         | 256        | 1      | 1       | 15x15   |
| 4 | Conv (ReLU) | 3x3         | 256        | 1      | 1       | 15x15   |
| 5 | Conv (ReLU) | 3x3         | 256        | 1      | 1       | 15x15   |
|   | MaxPool     | 3x3         |            | 2      | 0       | 7x7     |
| 6 | Conv (ReLU) | 7x7         | 512        | 1      | 0       | 1x1     |
|   | Dropout 0.5 |             |            |        |         | 1x1     |
| 7 | Conv (ReLU) | 1x1         | 512        | 1      | 0       | 1x1     |
|   | Dropout 0.5 |             |            |        |         | 1x1     |
| 8 | Conv (ReLU) | 1x1         | 250        | 1      | 0       | 1x1     |
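
For reference, the conv/pool output sizes in the table follow the usual formula output = floor((input - kernel + 2*padding) / stride) + 1. A quick check of the first two rows (not part of the original post):

def conv_out(size, kernel, stride, padding=0):
    # standard output-size formula for Conv2d / MaxPool2d
    return (size - kernel + 2 * padding) // stride + 1

print(conv_out(225, 15, 3))  # 71 (layer 1 conv)
print(conv_out(71, 3, 2))    # 35 (maxpool after layer 1)
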
class SketchModel(nn.Module):
    """
    https://github.com/HPrinz/sketch-recognition
    input size: (225, 225)
    """
    def __init__(self):
        super(SketchModel, self).__init__()
        self.output_num = 10
        self.filer_num = 64
        self.act = nn.LeakyReLU(0.2)
        self.conv1 = self.conv_block(1, self.filer_num*1, kernel_size=15, stride=3, padding=0, act_fn=self.act)
        self.conv2 = self.conv_block(self.filer_num*1, self.filer_num*2, kernel_size=3, stride=1, padding=0, act_fn=self.act)
        self.conv3 = self.conv_block(self.filer_num*2, self.filer_num*4, kernel_size=3, stride=1, padding=1, act_fn=self.act)
        self.conv4 = self.conv_block(self.filer_num*4, self.filer_num*4, kernel_size=3, stride=1, padding=1, act_fn=self.act)
        self.conv5 = self.conv_block(self.filer_num*4, self.filer_num*4, kernel_size=3, stride=1, padding=1, act_fn=self.act)
        self.conv6 = self.conv_block(self.filer_num*4, self.filer_num*8, kernel_size=7, stride=1, padding=0, act_fn=self.act)
        self.conv7 = self.conv_block(self.filer_num*8, self.filer_num*8, kernel_size=1, stride=1, padding=0, act_fn=self.act)
        
        self.conv8 = self.conv_block(self.filer_num*8, 50, kernel_size=1, stride=1, padding=0, act_fn=self.act)
        self.out = nn.Sequential(
            nn.Linear(50, self.output_num),
            nn.Sigmoid(),
        )
        self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)
        self.drop = nn.Dropout2d(p=0.25)
        
    def forward(self, input):
        h = self.pool(self.conv1(input))
        h = self.pool(self.conv2(h))
        h = self.pool(self.conv5(self.conv4(self.conv3(h))))
        h = self.drop(self.conv6(h))
        h = self.conv8(self.drop(self.conv7(h)))
        h = self.out(self.flatten(h))
        return h
    
    def conv_block(self, in_dim, out_dim, kernel_size, stride, padding, act_fn):
        return nn.Sequential(
            nn.Conv2d(in_dim, out_dim, kernel_size, stride, padding),
            act_fn,
        )
    
    def flatten(self, x):
        bs = x.size()[0]
        return x.view(bs, -1)
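
Before training, it is worth checking that a dummy input of the expected size flows through the network and produces a (batch, 10) output. This sanity check is not in the original post:

dummy = torch.zeros(2, 1, 225, 225)   # batch of 2 grayscale 225x225 images
print(SketchModel()(dummy).shape)     # expected: torch.Size([2, 10])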

Running the training


model = SketchModel().cuda()  # remove .cuda() when running on CPU
lr = 1e-4
optim = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss().cuda()  # remove .cuda() when running on CPU

def train(epoch, name):
    
    if not os.path.exists('models/' + name):
        os.makedirs('models/' + name)  # makedirs also creates the parent 'models' directory
    train_loss = np.array([])
    test_loss = np.array([])
    for i in range(epoch):
        
        loss_per_epoch = 0
        acc = 0
        for batch_idx, (imgs, labels) in enumerate(train_data_loader):
            model.train()
            optim.zero_grad()
            imgs = imgs.cuda().float()  # remove .cuda() when running on CPU
            labels = labels.cuda().long()
            estimated = model.forward(imgs)
            loss = criterion(estimated, labels)
            loss.backward()
            optim.step()
            loss_per_epoch += loss.data
            acc += torch.sum(labels == torch.argmax(estimated, dim=1)).cpu().numpy()
            
        train_loss = np.append(train_loss, loss_per_epoch)
        print("epoch: {}, train_loss: {}".format(i, train_loss[-1]))
        print("train_acc: {}".format(acc/len(train_data)))
    
        loss_per_epoch = 0
        acc = 0
        for batch_idx, (imgs, labels) in enumerate(test_data_loader):
            model.eval()
            imgs = imgs.cuda().float()  # remove .cuda() when running on CPU
            labels = labels.cuda().long()
            estimated = model.forward(imgs)
            loss = criterion(estimated, labels)
            loss_per_epoch += loss.data*len(train_data)/len(test_data)  # scaled so the test loss is comparable to the train loss
            acc += torch.sum(labels == torch.argmax(estimated, dim=1)).cpu().numpy()
            
        test_loss = np.append(test_loss, loss_per_epoch)
        print("epoch: {}, test_loss: {}".format(i, test_loss[-1]))
        print("test_acc: {}".format(acc/len(test_data)))

train(40, 'SketchModel')
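
train() creates a models/<name> directory, but the code above never writes anything into it; presumably checkpoints were saved there. A minimal sketch of what that could look like, e.g. at the end of each epoch inside train() (the file name is an assumption):

# inside the epoch loop of train()
torch.save(model.state_dict(), 'models/{}/epoch_{}.pth'.format(name, i))
# restoring later:
# model.load_state_dict(torch.load('models/SketchModel/epoch_39.pth'))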

Redefining the model

After training the model above for 40 epochs, the test accuracy (test_acc) only reached about 0.65.
So I

  • added BatchNormalization to each layer
  • inserted a partial ResNet-style (residual) structure

and test_acc improved to 0.87.
I would like to reach 0.9, but this will do for now.
Below are the redefined model and its training.

class SketchResModel(nn.Module):
    """
    https://github.com/HPrinz/sketch-recognition
    input size: (225, 225)
    """
    def __init__(self):
        super(SketchResModel, self).__init__()
        self.output_num = 10
        self.filer_num = 64
        self.act = nn.LeakyReLU(0.2)
        self.conv1 = self.conv_block(1, self.filer_num*1, kernel_size=15, stride=3, padding=0, act_fn=self.act)
        self.conv2 = self.conv_block(self.filer_num*1, self.filer_num*2, kernel_size=3, stride=1, padding=0, act_fn=self.act)
        self.conv3 = self.conv_block(self.filer_num*2, self.filer_num*4, kernel_size=3, stride=1, padding=1, act_fn=self.act)
        self.conv4 = self.conv_block(self.filer_num*4, self.filer_num*4, kernel_size=3, stride=1, padding=1, act_fn=self.act)
        self.conv5 = self.conv_block(self.filer_num*4, self.filer_num*4, kernel_size=3, stride=1, padding=1, act_fn=self.act)
        self.conv6 = self.conv_block(self.filer_num*4, self.filer_num*8, kernel_size=7, stride=1, padding=0, act_fn=self.act)
        self.conv7 = self.conv_block(self.filer_num*8, self.filer_num*8, kernel_size=1, stride=1, padding=0, act_fn=self.act)
        self.conv8 = self.conv_block(self.filer_num*8, 50, kernel_size=1, stride=1, padding=0, act_fn=self.act)
        self.out = nn.Sequential(
            nn.Linear(50, self.output_num),
            nn.Sigmoid(),
        )
        self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)
        self.drop = nn.Dropout2d(p=0.25)
        
        self.res_block = nn.Sequential(
            self.conv_block(self.filer_num*2, self.filer_num*4, kernel_size=3, stride=1, padding=1, act_fn=self.act),
            self.conv_block(self.filer_num*4, self.filer_num*4, kernel_size=3, stride=1, padding=1, act_fn=self.act),
            nn.Conv2d(self.filer_num*4, self.filer_num*4, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(self.filer_num*4),
        )
        self.residual = nn.Sequential(
            nn.Conv2d(self.filer_num*2, self.filer_num*4, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(self.filer_num*4),
        )
        
    def forward(self, input):
        h = self.pool(self.conv1(input))
        h = self.pool(self.conv2(h))
        h1 = self.residual(h)
        h = self.pool(self.act(self.res_block(h) + h1))
        h = self.drop(self.conv6(h))
        h = self.conv8(self.drop(self.conv7(h)))
        h = self.out(self.flatten(h))
        return h
    
    def conv_block(self, in_dim, out_dim, kernel_size, stride, padding, act_fn):
        return nn.Sequential(
            nn.Conv2d(in_dim, out_dim, kernel_size, stride, padding),
            nn.BatchNorm2d(out_dim),
            act_fn,
        )
    
    def flatten(self, x):
        bs = x.size()[0]
        return x.view(bs, -1)

data_transform2 = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize(225),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

train_data = torchvision.datasets.ImageFolder(root='./train', transform=data_transform2)
train_data_loader = torch.utils.data.DataLoader(train_data,
                                          batch_size=4,
                                          shuffle=True,
                                          num_workers=4)    
test_data = torchvision.datasets.ImageFolder(root='./test', transform=data_transform2)
test_data_loader = torch.utils.data.DataLoader(test_data,
                                          batch_size=4,
                                          shuffle=True,
                                          num_workers=4)
model = SketchResModel().cuda()  # remove .cuda() when running on CPU
lr = 1e-4
optim = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss().cuda()  # remove .cuda() when running on CPU
train(40, 'SketchResModel')
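
Finally, a minimal sketch of running inference on a single sketch with the trained model; the image path is just an example and this step is not part of the original post:

from PIL import Image

model.eval()
classes = train_data.classes  # class names in label-index order

img = Image.open('./test/horse/8536.png')    # any sketch image (example path)
x = data_transform(img).unsqueeze(0).cuda()  # remove .cuda() when running on CPU
with torch.no_grad():
    pred = model(x).argmax(dim=1).item()
print(classes[pred])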