[PyTorch intro] Playing with fine-tuning using torchvision.models ♬ When the input images are grayscale

Posted at 2021-02-15

Following on from last night, I tried fine-tuning again.
Tonight, unlike the usual setup, the input images are grayscale.
The task is gray image ⇒ category.

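To the human eye, grayscale images are still recognizable enough.
All that is needed is to make the first layer accept grayscale, i.e. 1-channel, input, so I will keep this write-up short.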
### What I did
・Converted the input images to grayscale
・Changed the first layer to accept 1 channel
・Ran the classification
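### ・Converting the input images to grayscale
This only requires applying the following transform when loading CIFAR-10.
The only change from last night's code is the added transforms.Grayscale().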

```python
self.transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((self.width, self.height)),
    transforms.Grayscale()  # new: RGB tensor -> 1-channel tensor
])
```
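As a rough illustration of what transforms.Grayscale() computes here: because ToTensor() runs first, Grayscale receives a tensor, and for tensors torchvision combines the R, G and B channels with the ITU-R 601-2 luma weights (0.2989, 0.587, 0.114), returning one channel by default. The NumPy sketch below mimics that for a (3, H, W) array; `rgb_to_gray` is a hypothetical helper name, and dtype/rounding details of the real implementation are ignored.

```python
import numpy as np

# ITU-R 601-2 luma coefficients (the weights torchvision uses for tensor input)
LUMA = np.array([0.2989, 0.587, 0.114])

def rgb_to_gray(img_chw: np.ndarray) -> np.ndarray:
    """Convert a (3, H, W) float image into a (1, H, W) grayscale image."""
    if img_chw.shape[0] != 3:
        raise ValueError("expected a 3-channel (C, H, W) image")
    # Weighted sum over the channel axis -> (H, W)
    gray = np.tensordot(LUMA, img_chw, axes=([0], [0]))
    # Keep a channel axis of size 1, like Grayscale(num_output_channels=1)
    return gray[np.newaxis, ...]

rgb = np.random.rand(3, 128, 128)
gray = rgb_to_gray(rgb)
print(gray.shape)  # (1, 128, 128)
```

The single output channel is exactly why the first convolution of each pretrained model has to be changed in the next step.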

This transform is then used when loading the datasets, as below; this part is unchanged.

```python
def setup(self, stage=None):
    # Assign train/val datasets for use in dataloaders
    if stage == 'fit' or stage is None:
        cifar_full = CIFAR10(self.data_dir, train=True, transform=self.transform)
        n_train = int(len(cifar_full)*0.8)
        n_val = len(cifar_full)-n_train
        self.cifar_train, self.cifar_val = torch.utils.data.random_split(cifar_full, [n_train, n_val])
    # Assign test dataset for use in dataloader(s)
    if stage == 'test' or stage is None:
        self.cifar_test = CIFAR10(self.data_dir, train=False, transform=self.transform)
```
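CIFAR-10's training split has 50,000 images, so the 80/20 split in setup() yields 40,000 training and 10,000 validation images. Because n_val is computed as the remainder, the two sizes always add up to the dataset length, which random_split requires for integer lengths. A tiny sketch of the same arithmetic (`split_sizes` is a hypothetical helper name):

```python
from typing import Tuple

def split_sizes(n_total: int, train_frac: float = 0.8) -> Tuple[int, int]:
    """Return (n_train, n_val) computed the same way as in setup()."""
    n_train = int(n_total * train_frac)  # int() truncates toward zero
    n_val = n_total - n_train            # the remainder goes to validation
    return n_train, n_val

print(split_sizes(50000))  # (40000, 10000)
```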

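### ・Changing the first layer to accept 1 channel
This is the key point of this article.
Look at the code below.
It is the code I used the other day for colorizing grayscale images, taken from Reference ①.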

```python
self.model_ = models.resnet18(pretrained=True)
# Change first conv layer to accept single-channel (grayscale) input
self.model_.conv1.weight = nn.Parameter(self.model_.conv1.weight.sum(dim=1).unsqueeze(1))
```
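One side effect worth knowing about, shown here on a randomly initialized stand-in layer with the same configuration as resnet18's conv1 (so nothing is downloaded): the assignment only swaps the weight tensor. The layer's in_channels attribute, and therefore the printed model, still says 3, while the forward pass, which only uses the weight's shape, now expects 1-channel input.

```python
import torch
from torch import nn

# Stand-in with the same configuration as resnet18's conv1 (random weights)
conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)

# The one-liner from the article: sum the kernels over the input-channel axis
conv1.weight = nn.Parameter(conv1.weight.sum(dim=1).unsqueeze(1))

print(conv1.weight.shape)  # torch.Size([64, 1, 7, 7])
print(conv1.in_channels)   # 3: the attribute is not updated by the assignment

# The forward pass only uses the weight tensor, so 1-channel input now works
out = conv1(torch.randn(2, 1, 128, 128))
print(out.shape)           # torch.Size([2, 64, 64, 64])
```

So a printed structure showing Conv2d(3, 64, ...) after the change is expected and does not mean the conversion failed.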

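[Reference]
Image Colorization with Convolutional Neural Networks
In Reference ①, it was used as shown below.
The code above is based on its first two lines (creating the ResNet and replacing conv1).
The reference code then keeps only the part it needs in the next line (the first six child modules, up to layer2).
This time, as with the color-image fine-tuning, I use the whole model.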

```python
## First half: ResNet
resnet = models.resnet18(num_classes=365)
# Change first conv layer to accept single-channel (grayscale) input
resnet.conv1.weight = nn.Parameter(resnet.conv1.weight.sum(dim=1).unsqueeze(1))
# Extract midlevel features from ResNet-gray
self.midlevel_resnet = nn.Sequential(*list(resnet.children())[0:6])
```

In this article, the following fc layer is appended to the model to build the whole network.

```python
self.f_resnet = nn.Sequential(
    # 1000 ImageNet logits -> 10 CIFAR-10 classes (plain int: np.int was removed in NumPy 1.24)
    nn.Linear(in_features=1000, out_features=10, bias=True)
)
```
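The resulting structure is simply "backbone that outputs 1000 ImageNet scores, followed by a 1000 → 10 linear layer". The sketch below shows that shape flow with a small stand-in backbone; the stand-in layers and the class name `GrayClassifier` are assumptions made only so the example runs without downloading pretrained weights.

```python
import torch
from torch import nn

class GrayClassifier(nn.Module):
    """Backbone with a 1000-dim output followed by a 1000 -> num_classes head."""

    def __init__(self, backbone: nn.Module, num_classes: int = 10):
        super().__init__()
        self.backbone = backbone
        self.head = nn.Linear(1000, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.head(self.backbone(x))

# Hypothetical stand-in for a torchvision model: 1-channel in, 1000 values out
stand_in = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=1),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 1000),
)
model = GrayClassifier(stand_in)
logits = model(torch.randn(4, 1, 128, 128))
print(logits.shape)  # torch.Size([4, 10])
```

Replacing the model's own final layer (e.g. resnet18's fc) would be an alternative design; the article instead keeps the full 1000-way output and stacks the new head on top.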

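By the way, what does the following line actually mean?
`self.model_.conv1.weight = nn.Parameter(self.model_.conv1.weight.sum(dim=1).unsqueeze(1))`
To find out, let's look at the original structure of resnet18, shown below.
It makes it clear at a glance that the first part of `self.model_.conv1.weight.sum(dim=1).unsqueeze(1)`, namely `self.model_.conv1.weight`, is the weight of the `conv1` entry in `ResNet((conv1): ...`.
That layer is `Conv2d(3, 64, kernel_size=(7, 7), ...)`, so its weight has shape (64, 3, 7, 7): `sum(dim=1)` adds up the kernels over the three input channels, giving (64, 7, 7), and `unsqueeze(1)` restores the channel axis as (64, 1, 7, 7), i.e. a layer that takes 1-channel input.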

```text
resnet18 Net(
  (net): customize_model(
    (model_): ResNet(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): BasicBlock(
...
```
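Why summing is a sensible choice: convolution is linear in its input channels, so convolving a gray image x with the summed kernel gives the same output as running the original 3-channel layer on an RGB image whose three channels all equal x. A quick numerical check with random weights (no pretrained weights are loaded here):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
w_rgb = torch.randn(64, 3, 7, 7)         # same shape as resnet18 conv1.weight
w_gray = w_rgb.sum(dim=1).unsqueeze(1)   # (64, 1, 7, 7), as in the article

x_gray = torch.rand(2, 1, 32, 32)        # batch of gray images
x_rgb = x_gray.repeat(1, 3, 1, 1)        # copy the gray values into R, G and B

# Same stride/padding as resnet18's conv1
y_gray = F.conv2d(x_gray, w_gray, stride=2, padding=3)
y_rgb = F.conv2d(x_rgb, w_rgb, stride=2, padding=3)
print(torch.allclose(y_gray, y_rgb, atol=1e-4))  # True
```

In other words, the converted layer behaves like the pretrained layer looking at a "colorless" RGB image, which is why the pretrained features remain usable.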

What about other networks?
The beginning of densenet121 looks like this.
For this one, the input layer can be reached via features.conv0.

```text
densenet121 Net(
  (net): customize_model(
    (model_): DenseNet(
      (features): Sequential(
        (conv0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        (norm0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu0): ReLU(inplace=True)
```

Next, mobilev2 (MobileNetV2)
is a bit more complex, but features[0][0] seems to work.

```text
mobilev2 Net(
  (net): customize_model(
    (model_): MobileNetV2(
      (features): Sequential(
        (0): ConvBNReLU(
          (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
```
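The access paths found so far (conv1 for ResNet, features.conv0 for DenseNet, features[0][0] for MobileNetV2) can be handled by one helper that walks a dotted path. The sketch below is a hypothetical helper, not part of torchvision; it assumes the target is an ordinary nn.Conv2d with groups=1 and builds a fresh 1-channel layer from the summed kernels, so in_channels is also correct. A small toy model with MobileNetV2-style nesting stands in for the real network.

```python
import torch
from torch import nn

def replace_first_conv(model: nn.Module, path: str) -> nn.Conv2d:
    """Replace the Conv2d at dotted `path` with a 1-channel copy (kernels summed).

    Example paths: "conv1" (ResNet), "features.conv0" (DenseNet),
    "features.0.0" (MobileNetV2). nn.Sequential children are named "0", "1", ...
    """
    *parent_names, leaf = path.split(".")
    parent = model
    for name in parent_names:
        parent = getattr(parent, name)
    old = getattr(parent, leaf)
    assert isinstance(old, nn.Conv2d) and old.groups == 1

    new = nn.Conv2d(1, old.out_channels, kernel_size=old.kernel_size,
                    stride=old.stride, padding=old.padding,
                    dilation=old.dilation, bias=old.bias is not None)
    with torch.no_grad():
        new.weight.copy_(old.weight.sum(dim=1, keepdim=True))
        if old.bias is not None:
            new.bias.copy_(old.bias)
    setattr(parent, leaf, new)  # re-registers the submodule under the same name
    return new

# Toy model with the same nesting as MobileNetV2's first block
class ToyMobile(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Sequential(nn.Conv2d(3, 32, 3, stride=2, padding=1, bias=False),
                          nn.BatchNorm2d(32), nn.ReLU6(inplace=True)))

    def forward(self, x):
        return self.features(x)

toy = ToyMobile()
replace_first_conv(toy, "features.0.0")
print(toy.features[0][0])
print(toy(torch.randn(1, 1, 64, 64)).shape)  # torch.Size([1, 32, 32, 32])
```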

Finally, vgg16_bn.
I had a feeling its first layer could not be addressed without going through features, so in this case
I first define the network as follows:

```python
model_0 = models.vgg16_bn(pretrained=True)
# children()[0] is vgg16_bn.features; avgpool and classifier are dropped
self.model_ = nn.Sequential(*list(model_0.children())[0])  # VGG16_bn
```

This gives the following output:

```text
vgg16 Net(
  (net): customize_model(
    (model_): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
```

Now the first layer can be addressed as self.model_[0].
Putting it all together, the network class can be defined as follows:

```python
class customize_model(nn.Module):
    def __init__(self, input_size=128, sel_model='resnet18'):
        super(customize_model, self).__init__()

        ## Select model and convert its first conv layer to 1-channel input
        if sel_model == 'resnext50':
            self.model_ = models.resnext50_32x4d(pretrained=True)
            self.model_.conv1.weight = nn.Parameter(self.model_.conv1.weight.sum(dim=1).unsqueeze(1))
        elif sel_model == 'vgg16':
            model_0 = models.vgg16_bn(pretrained=True)
            self.model_ = nn.Sequential(*list(model_0.children())[0])  # keep only vgg16_bn.features
            self.model_[0].weight = nn.Parameter(self.model_[0].weight.sum(dim=1).unsqueeze(1))
        elif sel_model == 'wide50':
            self.model_ = models.wide_resnet50_2(pretrained=True)
            self.model_.conv1.weight = nn.Parameter(self.model_.conv1.weight.sum(dim=1).unsqueeze(1))
        elif sel_model == 'mobilev2':
            self.model_ = models.mobilenet_v2(pretrained=True)
            self.model_.features[0][0].weight = nn.Parameter(self.model_.features[0][0].weight.sum(dim=1).unsqueeze(1))
        elif sel_model == 'densenet121':
            self.model_ = models.densenet121(pretrained=True)
            self.model_.features.conv0.weight = nn.Parameter(self.model_.features.conv0.weight.sum(dim=1).unsqueeze(1))
        else:
            self.model_ = models.resnet18(pretrained=True)
            self.model_.conv1.weight = nn.Parameter(self.model_.conv1.weight.sum(dim=1).unsqueeze(1))

        # Fine-tune all layers (set False to freeze them)
        for i, param in enumerate(self.model_.parameters()):
            param.requires_grad = True  # False
            print(i, param.requires_grad)

        if sel_model == 'vgg16':
            # VGG features output 512 x 4 x 4 for a 128 x 128 input, so flatten first
            self.f_resnet = nn.Sequential(
                nn.Flatten(),
                nn.Linear(in_features=512*4*4, out_features=2048, bias=True),
                nn.ReLU(),
                nn.Dropout(p=0.2, inplace=False),
                nn.Linear(in_features=512*2*2, out_features=1000, bias=True),  # 512*2*2 == 2048
                nn.Linear(in_features=1000, out_features=10, bias=True)
            )
        else:
            # The other backbones already output 1000 ImageNet logits
            self.f_resnet = nn.Sequential(
                nn.Linear(in_features=1000, out_features=10, bias=True)
            )

    def forward(self, input):
        midlevel_features = self.model_(input)
        output = self.f_resnet(midlevel_features)
        return output
```
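Where does 512*4*4 in the vgg16 head come from? Keeping only vgg16_bn's features drops its adaptive avgpool, so the flattened size depends on the input size. The 3x3 convolutions use padding 1 and keep H and W unchanged, the five 2x2 max-pool layers each halve them, and the last block has 512 channels. A small sketch of that arithmetic (`vgg_flat_features` is a hypothetical helper name):

```python
def vgg_flat_features(input_size: int, channels: int = 512, num_pools: int = 5) -> int:
    """Flattened size of the VGG16 `features` output for a square input."""
    side = input_size
    for _ in range(num_pools):
        side //= 2  # each MaxPool2d(kernel_size=2, stride=2) halves H and W
    return channels * side * side

print(vgg_flat_features(128))  # 8192 == 512*4*4
print(vgg_flat_features(32))   # 512
```

This also shows that the vgg16 head only fits the 128 x 128 input used in the training run; with the raw 32 x 32 CIFAR-10 size, the first Linear layer would need 512 inputs instead.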

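### ・Classification results
I got the following results. The right-hand columns show last night's results with color images for comparison.
Compared with color images, the training time is about the same, while the classification accuracy drops by roughly 2 to 3 percentage points.
Much like for humans, I suppose.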

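| Model | Training time (gray) | val_acc_gray | Training time (color) | val_acc_color | acc_g/acc_c (%) |
|---|---|---|---|---|---|
| resnet18 | 1113 | 0.8747 | 1200 | 0.9046 | 96.7 |
| mobilev2 | 1280 | 0.8867 | 1247 | 0.9196 | 96.4 |
| densenet121 | 2457 | 0.9135 | 2531 | 0.9365 | 97.5 |
| vgg16_bn | 3863 | 0.8705 | 4190 (a) | 0.889 (a) | 97.9 (b) |

(a) j <= 44: False (1.6MB) (128,128) batch32
(b) Reference data: ratio between results obtained under different conditions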
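The acc_g/acc_c column is the gray validation accuracy divided by the color one, in percent. Recomputing it from the table values:

```python
# (val_acc_gray, val_acc_color) copied from the results table
results = {
    "resnet18": (0.8747, 0.9046),
    "mobilev2": (0.8867, 0.9196),
    "densenet121": (0.9135, 0.9365),
    "vgg16_bn": (0.8705, 0.889),  # color run used different conditions (note a)
}

# Gray accuracy as a percentage of color accuracy, rounded like the table
ratios = {name: round(100 * g / c, 1) for name, (g, c) in results.items()}
print(ratios)  # {'resnet18': 96.7, 'mobilev2': 96.4, 'densenet121': 97.5, 'vgg16_bn': 97.9}
```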

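### Summary
・Tried fine-tuning for classification on grayscale images
・Accuracy drops, but the grayscale models still reach about 96-97% of the color-image accuracy

・Next, I would like to try colorization with densenet121 and mobilev2, building on these results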
### Code
As for the hard part: I ran the network model code below one model at a time, inspecting each network's structure, and made the adjustments described above.

Network model code

```python
import torchvision.models as models
from torchsummary import summary
import torch
import pytorch_lightning as pl
from torch import nn
import torch.nn.functional as F
import numpy as np
import time

class customize_model(nn.Module):
    def __init__(self, input_size=128, sel_model='resnet18'):
        super(customize_model, self).__init__()

        ## Select model and convert its first conv layer to 1-channel input
        if sel_model == 'resnext50':
            self.model_ = models.resnext50_32x4d(pretrained=True)
            self.model_.conv1.weight = nn.Parameter(self.model_.conv1.weight.sum(dim=1).unsqueeze(1))
        elif sel_model == 'vgg16':
            model_0 = models.vgg16_bn(pretrained=True)
            self.model_ = nn.Sequential(*list(model_0.children())[0])  # VGG16_bn: keep only features
            self.model_[0].weight = nn.Parameter(self.model_[0].weight.sum(dim=1).unsqueeze(1))
        elif sel_model == 'wide50':
            self.model_ = models.wide_resnet50_2(pretrained=True)
            self.model_.conv1.weight = nn.Parameter(self.model_.conv1.weight.sum(dim=1).unsqueeze(1))
        elif sel_model == 'mobilev2':
            self.model_ = models.mobilenet_v2(pretrained=True)
            self.model_.features[0][0].weight = nn.Parameter(self.model_.features[0][0].weight.sum(dim=1).unsqueeze(1))
        elif sel_model == 'densenet121':
            self.model_ = models.densenet121(pretrained=True)
            self.model_.features.conv0.weight = nn.Parameter(self.model_.features.conv0.weight.sum(dim=1).unsqueeze(1))
        else:
            self.model_ = models.resnet18(pretrained=True)
            self.model_.conv1.weight = nn.Parameter(self.model_.conv1.weight.sum(dim=1).unsqueeze(1))

        # Fine-tune all layers (set False to freeze them)
        for i, param in enumerate(self.model_.parameters()):
            param.requires_grad = True  # False
            print(i, param.requires_grad)
        if sel_model == 'vgg16':
            # VGG features output 512 x 4 x 4 for a 128 x 128 input, so flatten first
            self.f_resnet = nn.Sequential(
                nn.Flatten(),
                nn.Linear(in_features=512*4*4, out_features=2048, bias=True),
                nn.ReLU(),
                nn.Dropout(p=0.2, inplace=False),
                nn.Linear(in_features=512*2*2, out_features=1000, bias=True),  # 512*2*2 == 2048
                nn.Linear(in_features=1000, out_features=10, bias=True)
            )
        else:
            # The other backbones already output 1000 ImageNet logits
            self.f_resnet = nn.Sequential(
                nn.Linear(in_features=1000, out_features=10, bias=True)
            )

    def forward(self, input):
        midlevel_features = self.model_(input)
        output = self.f_resnet(midlevel_features)
        return output

def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # for gpu
    # Assuming that we are on a CUDA machine, this should print a CUDA device:
    print(device)
    pl.seed_everything(0)

    model = models.vgg16_bn(pretrained=False, num_classes=10)
    model = model.to(device)  # for gpu
    print('Gray_model', model)

    dim = (3, 128, 128)
    # torchsummary defaults to device="cuda"; pass the actual device type
    summary(model, dim, device=device.type)

if __name__ == '__main__':
    start_time = time.time()
    main()
    print('elapsed time: {:.3f} [sec]'.format(time.time() - start_time))
```

Training code

```python
import argparse
import os
import time
import numpy as np

from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning import Trainer
import torch.nn as nn
import torch.nn.functional as F

import torch
import torchvision
import torchvision.transforms as transforms

from torchvision.datasets import CIFAR10
import torch.optim as optim
from torchsummary import summary

import matplotlib.pyplot as plt
import tensorboard

# from sam import SAM  # external sam.py; only needed for the commented-out SAM optimizer below

from model_print import customize_model  # the network model code above
#from net_cifar10 import Net_cifar10
#from net_vgg16 import VGG16

# functions to show an image
def imshow(img):
    # No Normalize is applied in the transform, so pixel values are already in [0, 1]
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.pause(3)
    plt.close()

from pytorch_lightning.callbacks import Callback
from pytorch_lightning.callbacks import EarlyStopping, ProgressBar, LearningRateMonitor

class MyPrintingCallback(Callback):

    def on_init_start(self, trainer):
        print('Starting to init trainer!')

    def on_init_end(self, trainer):
        print('Trainer is init now')

    def on_epoch_end(self, trainer, pl_module):
        #print("pl_module = ",pl_module)
        #print("trainer = ",trainer)
        print('')

    def on_train_end(self, trainer, pl_module):
        print('do something when training ends')

class LitProgressBar(ProgressBar):

    def init_validation_tqdm(self):
        bar = super().init_validation_tqdm()
        bar.set_description('running validation ...')
        return bar

class Net(pl.LightningModule):
    def __init__(self, input_size=32, sel_model='resnet18', data_dir='./'):
        super().__init__()
        self.data_dir = data_dir

        # Hardcode some dataset specific attributes
        self.num_classes = 10
        self.dims = (1, input_size, input_size)  # 1 channel after Grayscale()
        self.channels, self.width, self.height = self.dims
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((self.width, self.height)),
            transforms.Grayscale()
            # If re-enabled, use 1-channel statistics, e.g. transforms.Normalize((0.5,), (0.25,))
            #transforms.Normalize((0.5, 0.5, 0.5), (0.25, 0.25, 0.25))
        ])

        #self.net = Net_cifar10()
        #self.net = VGG16()
        self.net = customize_model(input_size=input_size, sel_model=sel_model)  # 'vgg16', 'wide50', 'mobilev2', 'densenet121'

        # pl.metrics is the PyTorch Lightning 1.x API (later moved to torchmetrics)
        self.train_acc = pl.metrics.Accuracy()
        self.val_acc = pl.metrics.Accuracy()
        self.test_acc = pl.metrics.Accuracy()

    def forward(self, x):
        return self.net(x)

    def training_step(self, batch, batch_idx):
        x, t = batch
        y = self.net(x)
        loss = F.cross_entropy(y, t)
        # Log training metrics under train_* names, using the training accuracy metric
        self.log('train_loss', loss, on_step=False, on_epoch=True)
        self.log('train_acc', self.train_acc(y, t), on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, t = batch
        y = self(x)
        loss = F.cross_entropy(y, t)
        preds = torch.argmax(y, dim=1)

        # Calling self.log will surface up scalars for you in TensorBoard
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', self.val_acc(y, t), prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        # Here we just reuse the validation_step for testing
        return self.validation_step(batch, batch_idx)

    def configure_optimizers(self):
        #optimizer = optim.SGD(self.parameters(), lr=0.001, momentum=0.9)
        optimizer = optim.Adam(self.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
        #base_optimizer = torch.optim.SGD
        #optimizer = SAM(self.net.parameters(), base_optimizer, lr=0.1, momentum=0.9)
        #scheduler = {'scheduler': optim.lr_scheduler.LambdaLR(optimizer, lr_lambda = lambda epoch: 0.95 ** epoch)}
        # Multiply the learning rate by 0.2 every 5 epochs
        scheduler = {'scheduler': optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.2)}
        print("CF; lr = ", optimizer.param_groups[0]['lr'])
        return [optimizer], [scheduler]

    def prepare_data(self):
        # download
        CIFAR10(self.data_dir, train=True, download=True)
        CIFAR10(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == 'fit' or stage is None:
            cifar_full = CIFAR10(self.data_dir, train=True, transform=self.transform)
            n_train = int(len(cifar_full)*0.8)
            n_val = len(cifar_full)-n_train
            self.cifar_train, self.cifar_val = torch.utils.data.random_split(cifar_full, [n_train, n_val])

        # Assign test dataset for use in dataloader(s)
        if stage == 'test' or stage is None:
            self.cifar_test = CIFAR10(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        #classes = tuple(np.linspace(0, 9, 10, dtype=np.uint8))
        classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
        trainloader = DataLoader(self.cifar_train, shuffle=True, drop_last=True, batch_size=32, num_workers=0)
        # get some random training images
        dataiter = iter(trainloader)
        images, labels = next(dataiter)  # dataiter.next() was removed in recent PyTorch
        # show images
        imshow(torchvision.utils.make_grid(images))
        # print labels of the first four images
        print(' '.join('%5s' % classes[labels[j]] for j in range(4)))
        return trainloader

    def val_dataloader(self):
        return DataLoader(self.cifar_val, shuffle=False, batch_size=32, num_workers=0)

    def test_dataloader(self):
        return DataLoader(self.cifar_test, shuffle=False, batch_size=32, num_workers=0)


def main():
    '''
    main
    '''
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # for gpu
    # Assuming that we are on a CUDA machine, this should print a CUDA device:
    print(device)
    pl.seed_everything(0)
    # model
    size_ = 32*4  # 32*8
    sel_model_ = 'resnet18'   #'resnext50' #'densenet121'  #'wide50'  #'mobilev2' #'vgg16' #'resnet18'
    net = Net(input_size=size_, sel_model=sel_model_)  # 'vgg16', 'wide50', 'mobilev2', 'densenet121'
    model = net.to(device)  # for gpu
    print(sel_model_, model)
    summary(model, (1, size_, size_), device=device.type)  # 1-channel (gray) input
    # Accuracy should be maximized, hence mode='max'; note this callback is not passed to the Trainer below
    early_stopping = EarlyStopping(monitor='val_acc', mode='max')  # ('val_loss'
    bar = LitProgressBar(refresh_rate=10, process_position=1)
    #lr_monitor = LearningRateMonitor(logging_interval='step')

    trainer = pl.Trainer(gpus=1 if torch.cuda.is_available() else 0, max_epochs=10,
                         callbacks=[MyPrintingCallback(), bar])  # progress_bar_refresh_rate=10
    trainer.fit(model)
    path_ = './categorize/'
    os.makedirs(path_, exist_ok=True)  # make sure the checkpoint directory exists
    PATH = path_ + 'model_{}_cifar10_.ckpt'.format(sel_model_)
    trainer.save_checkpoint(PATH)
    results = trainer.test()
    print(results)

if __name__ == '__main__':
    start_time = time.time()
    main()
    print('elapsed time: {:.3f} [sec]'.format(time.time() - start_time))
```