More than 1 year has passed since last update.

Paper Study 9: Xception

Posted at

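These are my notes on the architecture of Xception and on implementing it.
Rather than covering the whole paper, I focus mainly on the architecture.
They are only study notes, and the implementation does not follow the paper exactly, so please keep that in mind.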

I implement the following paper.

Title: Xception: Deep Learning with Depthwise Separable Convolutions

Xception

depthwise separable convolution

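As shown in the figure, a 1×1 convolution first maps the input to the desired number of output channels; these channels are then split into non-overlapping segments, a 3×3 convolution is applied to each segment separately, and finally the outputs are concatenated.
This is called a (depthwise) separable convolution. Strictly speaking, the standard "separable convolution" applies the 3×3 (depthwise) convolution first and the 1×1 (pointwise) convolution afterwards, but because these operations are stacked layer after layer, the difference in order is considered unimportant.
In Xception, each segment is a single channel, so the 3×3 convolution is applied to every channel independently.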
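To make the per-channel behavior concrete, here is a framework-free sketch in plain Python (all function names are made up for this illustration). depthwise_conv3x3 convolves each channel with its own 3×3 kernel and never mixes channels; pointwise_conv is the 1×1 convolution that mixes channels at every pixel. It assumes the depthwise-then-pointwise order of the standard separable convolution, zero "same" padding and stride 1, and, like deep learning frameworks, it computes cross-correlation (the kernel is not flipped).

```python
from typing import List

# A feature map is stored as [channels][height][width].
Tensor3 = List[List[List[float]]]
Kernel3x3 = List[List[float]]


def depthwise_conv3x3(x: Tensor3, kernels: List[Kernel3x3]) -> Tensor3:
    """Apply kernels[c] to channel c only (zero 'same' padding, stride 1)."""
    channels, h, w = len(x), len(x[0]), len(x[0][0])
    out = [[[0.0] * w for _ in range(h)] for _ in range(channels)]
    for c in range(channels):
        for i in range(h):
            for j in range(w):
                acc = 0.0
                for di in range(3):
                    for dj in range(3):
                        ii, jj = i + di - 1, j + dj - 1
                        if 0 <= ii < h and 0 <= jj < w:  # outside the map counts as zero
                            acc += kernels[c][di][dj] * x[c][ii][jj]
                out[c][i][j] = acc
    return out


def pointwise_conv(x: Tensor3, weights: List[List[float]]) -> Tensor3:
    """1x1 convolution: output channel o = sum over c of weights[o][c] * x[c]."""
    h, w = len(x[0]), len(x[0][0])
    return [
        [[sum(weights[o][c] * x[c][i][j] for c in range(len(x))) for j in range(w)]
         for i in range(h)]
        for o in range(len(weights))
    ]


def separable_conv(x: Tensor3, dw_kernels: List[Kernel3x3],
                   pw_weights: List[List[float]]) -> Tensor3:
    """Depthwise 3x3 per channel, then a 1x1 conv across channels."""
    return pointwise_conv(depthwise_conv3x3(x, dw_kernels), pw_weights)


if __name__ == "__main__":
    x = [[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
         [[0, 1, 0], [1, 0, 1], [0, 1, 0]]]
    identity = [[0, 0, 0], [0, 1, 0], [0, 0, 0]]
    # identity depthwise kernels, then a 1x1 conv that sums the two channels
    print(separable_conv(x, [identity, identity], [[1, 1]]))
    # → [[[1.0, 3.0, 3.0], [5.0, 5.0, 7.0], [7.0, 9.0, 9.0]]]
```

Only the pointwise step sees more than one channel; changing the kernel of channel 0 never affects channel 1 of the depthwise output.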

image.png

If we consider a simple Inception module without a pooling branch (top figure), we can see that it can be expressed with this separable convolution (bottom figure).
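The number of segments interpolates between a regular Inception-style module and a depthwise separable convolution. The helper below (a hypothetical name; it counts weights only and ignores biases) counts the weights of a 1×1 convolution followed by a 3×3 convolution applied to each of num_segments non-overlapping channel groups: one segment is an ordinary 3×3 convolution, and one channel per segment is the Xception setting.

```python
def segmented_module_weights(c_in: int, c_out: int, num_segments: int, k: int = 3) -> int:
    """Weights of a 1x1 conv (c_in -> c_out) followed by a k x k conv that is
    applied separately to `num_segments` equal, non-overlapping channel groups."""
    if c_out % num_segments != 0:
        raise ValueError("c_out must be divisible by num_segments")
    group = c_out // num_segments
    pointwise = c_in * c_out                          # 1x1 conv mixes all channels
    spatial = num_segments * (k * k * group * group)  # each group maps group -> group channels
    return pointwise + spatial


if __name__ == "__main__":
    for segments in (1, 4, 728):
        print(segments, segmented_module_weights(728, 728, segments))
```

With 728 channels (the Middle flow width), a single segment needs 5,299,840 weights, four segments need 1,722,448, and one channel per segment needs only 536,536.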

image.png
image.png

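Let's look at the architecture in detail.
It is divided into three parts: the Entry flow, the Middle flow, and the Exit flow. The Middle flow is repeated eight times.
Apart from the residual connections, the network is a linear stack of separable convolutions.
Every convolution layer is followed by batch normalization.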

image.png
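Every stride-2 stage in the Entry and Exit flows has a matching strided 1×1 shortcut, so the main path and the shortcut must end up with the same spatial size for the residual addition to work. The sketch below (helper names are hypothetical) traces a 299×299 input through the stages as implemented later in this article, using the PyTorch output-size formula floor((n + 2p − k) / s) + 1. Keras' padding='same' gives ceil(n / s), which agrees with these numbers; with unpadded first convolutions the sizes would differ slightly.

```python
def conv_out(n: int, kernel: int, stride: int, padding: int) -> int:
    """Output size of a conv/pool layer: floor((n + 2*padding - kernel) / stride) + 1."""
    return (n + 2 * padding - kernel) // stride + 1


def trace_xception(n: int = 299):
    """Return (stage, main-path size, shortcut size or None, channels) per stage."""
    rows = []
    n = conv_out(n, 3, 2, 1)                     # conv1: 3x3, stride 2
    rows.append(("entry conv1", n, None, 32))
    n = conv_out(n, 3, 1, 1)                     # conv2: 3x3, stride 1
    rows.append(("entry conv2", n, None, 64))
    for name, ch in (("entry block1", 128), ("entry block2", 256), ("entry block3", 728)):
        main = conv_out(n, 3, 2, 1)              # separable convs keep size; 3x3/2 max pool halves it
        short = conv_out(n, 1, 2, 0)             # 1x1 stride-2 shortcut
        rows.append((name, main, short, ch))
        n = main
    rows.append(("middle flow (x8)", n, n, 728))  # identity shortcuts, size unchanged
    rows.append(("exit block", conv_out(n, 3, 2, 1), conv_out(n, 1, 2, 0), 1024))
    rows.append(("exit sepconvs + global pool", 1, None, 2048))  # global pooling -> 1x1
    return rows


if __name__ == "__main__":
    for stage, main, short, ch in trace_xception():
        print(f"{stage:28s} {main:>4} {short!s:>5} {ch:>5}")
```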

Unlike Inception V1 to V3, Xception has no auxiliary classifiers branching off intermediate layers.

Training

SGD is used with momentum 0.9.
The learning rate starts at 0.045 and is multiplied by 0.94 every 2 epochs.
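This is a stepped exponential decay: lr(epoch) = 0.045 × 0.94^⌊epoch / 2⌋ with 0-based epochs. A minimal sketch (the function name is made up here):

```python
def xception_lr(epoch: int, initial_lr: float = 0.045, drop: float = 0.94,
                epochs_per_drop: int = 2) -> float:
    """Learning rate for a 0-based epoch: multiply by `drop` once every `epochs_per_drop` epochs."""
    return initial_lr * drop ** (epoch // epochs_per_drop)


if __name__ == "__main__":
    for epoch in range(6):
        print(epoch, round(xception_lr(epoch), 6))
```

Epochs 0 and 1 use 0.045, epochs 2 and 3 use 0.045 × 0.94 = 0.0423, and so on. This is the behavior that the Keras LearningRateScheduler callback and PyTorch's StepLR(step_size=2, gamma=0.94) below are meant to reproduce when stepped once per epoch.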

Implementation

keras

Import the required libraries.

# Use tf.keras throughout; mixing in the standalone keras package can fail with recent TensorFlow.
import tensorflow.keras as keras
import tensorflow as tf
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Input, Conv2D, Activation, MaxPooling2D, AveragePooling2D, Flatten, Dense, Dropout, GlobalAveragePooling2D, BatchNormalization, Add, SeparableConv2D, concatenate
from tensorflow.keras import backend as K
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.callbacks import LearningRateScheduler
from tensorflow.keras.datasets import cifar10
import numpy as np
import math
import cv2

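Define separable convolution, batch normalization, and ReLU together as a single unit.
Keras provides SeparableConv2D, so we use it here. With inv=True the order is ReLU → SeparableConv2D → BatchNormalization (ReLU before each separable convolution, as in the paper's architecture figure); otherwise it is SeparableConv2D → BatchNormalization → ReLU.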

class conv_relu_bn(Model):
    def __init__(self, out_channels, act=True, inv=False):
        super().__init__()
        self.act = act
        self.inv = inv
    
        if act:
            self.relu = Activation("relu")
        self.conv = SeparableConv2D(out_channels, kernel_size=3, padding='same')
        self.bn = BatchNormalization()
    
    def call(self, x):
        if self.inv:
            if self.act:
                x = self.relu(x)
            x = self.conv(x)
            x = self.bn(x)
        else:
            x = self.conv(x)
            x = self.bn(x)
            if self.act:
                x = self.relu(x)
        return x
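When reading the model summary, it helps to know how Keras counts the parameters of this unit. With the defaults (depth_multiplier=1, use_bias=True), SeparableConv2D holds a k×k depthwise kernel per input channel, a 1×1 pointwise kernel, and one bias per output channel; BatchNormalization holds four values per channel (trainable gamma and beta plus the non-trainable moving mean and variance). A plain-Python sketch of that arithmetic (the helper name is hypothetical):

```python
def sepconv_bn_params(c_in: int, c_out: int, k: int = 3,
                      depth_multiplier: int = 1, use_bias: bool = True):
    """Return (SeparableConv2D params, BatchNormalization params) for one conv_relu_bn unit."""
    depthwise = k * k * c_in * depth_multiplier   # one k x k filter per input channel
    pointwise = c_in * depth_multiplier * c_out   # 1x1 conv across channels
    bias = c_out if use_bias else 0
    bn = 4 * c_out  # gamma, beta (trainable) + moving mean, moving variance (non-trainable)
    return depthwise + pointwise + bias, bn


if __name__ == "__main__":
    # first unit of block1 in the Entry flow: 64 -> 128 channels
    print(sepconv_bn_params(64, 128))  # → (8896, 512)
```

For comparison, a regular 3×3 Conv2D from 64 to 128 channels would need 9·64·128 + 128 = 73,856 parameters.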

Implement xception_block.
It combines two separable convolution units with either a max-pooling layer or a third separable convolution unit.

class xception_block(Model):
    def __init__(self, out_channels1, out_channels2=None, act=True, pool=True):
        super().__init__()
        if not out_channels2:
            out_channels2 = out_channels1
        
        self.conv1 = conv_relu_bn(out_channels1, act=act)
        self.conv2 = conv_relu_bn(out_channels2)
        if pool:
            self.layer3 = MaxPooling2D(pool_size=3, strides=2, padding='same')
        else:
            self.layer3 = conv_relu_bn(out_channels2)
            
    def call(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.layer3(x)
        return x

Implement the Entry flow.

class Ently_flow(Model):
    def __init__(self):
        super().__init__()
        self.conv1 = Conv2D(32, kernel_size=3, strides=2, padding='same')
        self.relu1 = Activation("relu")
        
        self.conv2 = Conv2D(64, kernel_size=3, padding='same')
        self.relu2 = Activation("relu")
        
        self.block1 = xception_block(out_channels1=128, act=False)
        self.block2 = xception_block(out_channels1=256)
        self.block3 = xception_block(out_channels1=728)

        self.res1 = Conv2D(128, kernel_size=1, strides=2, padding='same')
        self.res2 = Conv2D(256, kernel_size=1, strides=2, padding='same')
        self.res3 = Conv2D(728, kernel_size=1, strides=2, padding='same')
        
        self.add1 = Add()
        self.add2 = Add()
        self.add3 = Add()
        
    def call(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        
        x = self.conv2(x)
        x = self.relu2(x)
        
        # Each shortcut is a strided 1x1 conv of the corresponding block's input, as in the paper.
        out1 = self.block1(x)
        out1 = self.add1([out1, self.res1(x)])
        out2 = self.block2(out1)
        out2 = self.add2([out2, self.res2(out1)])
        out3 = self.block3(out2)
        out3 = self.add3([out3, self.res3(out2)])
        
        return out3

Implement the Middle flow block.
Since it is repeated eight times, Middle_flow stacks eight copies of it.

class Middle_block(Model):
    def __init__(self):
        super().__init__()
        self.block = xception_block(out_channels1=728, pool=False)
        self.add = Add()
        
    def call(self, x):
        res = x
        out = self.block(x)
        out = self.add([out, res])
        
        return out
class Middle_flow(Model):
    def __init__(self):
        super().__init__()
        self.xcept_layers = [Middle_block() for _ in range(8)]
        
    def call(self, x):
        for layer in self.xcept_layers:
            x = layer(x)
        
        return x
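The paper describes Xception as 36 convolutional layers structured into 14 modules, all with residual connections except the first and last. A quick tally of the convolution layers in this article's implementation, not counting the 1×1 shortcut convolutions, gives the same 36. The counts are read off the classes above, and grouping them into modules this way is my own assumption that happens to match the paper's 14.

```python
# Convolution layers per part, read off this article's classes
# (1x1 shortcut convolutions excluded).
ENTRY = 2 + 3 * 2   # conv1, conv2, then 3 xception_blocks with 2 separable convs each
MIDDLE = 8 * 3      # 8 Middle_blocks, each an xception_block(pool=False) with 3 separable convs
EXIT = 2 + 2        # exit xception_block (2 separable convs) + conv1 (1536) + conv2 (2048)

# Assumed module grouping: stem, 3 entry blocks, 8 middle blocks, exit block, final convs.
MODULES = 1 + 3 + 8 + 1 + 1

if __name__ == "__main__":
    print("conv layers:", ENTRY + MIDDLE + EXIT, "modules:", MODULES)
```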

Implement the Exit flow.

class Exit_flow(Model):
    def __init__(self):
        super().__init__()
        
        self.block = xception_block(out_channels1=728, out_channels2=1024, pool=True)
        
        self.res = Conv2D(1024, kernel_size=1, strides=2, padding='same')
        self.add = Add()
        
        self.conv1 = conv_relu_bn(out_channels=1536, inv=True)
        self.conv2 = conv_relu_bn(out_channels=2048, inv=True)
        
        self.pool = GlobalAveragePooling2D()
        
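        # ReLU between the fully connected layers; without it the stacked Dense layers collapse into one linear map
        self.fc1 = Dense(1024, activation='relu')
        self.fc2 = Dense(512, activation='relu')
        self.fc3 = Dense(10, activation='softmax')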
    
    def call(self, x):
        res = x
        res = self.res(res)
        
        x = self.block(x)
        x = self.add([x, res])
        
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.pool(x)
        
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x

Implement the Xception model itself.

class Xception(Model):
    def __init__(self):
        super().__init__()
        self.entry = Ently_flow()
        self.middle = Middle_flow()
        self.exit = Exit_flow()
        
    def call(self, x):
        x = self.entry(x)
        x = self.middle(x)
        x = self.exit(x)
        return x

Check the model structure.

model = Xception()
model.build((None, 299, 299, 3))  # build with input shape.
dummy_input = Input(shape=(299, 299, 3))  # declare without the batch dimension.
model_summary = Model(inputs=[dummy_input], outputs=model.call(dummy_input))
model_summary.summary()

image.png

Configure training.

epochs = 100
initial_lrate = 0.045  # initial learning rate from the paper

def decay(epoch):
    # Multiply the learning rate by 0.94 every 2 epochs (epoch is 0-based).
    drop = 0.94
    epochs_drop = 2
    lrate = initial_lrate * math.pow(drop, math.floor(epoch / epochs_drop))
    return lrate

# The learning-rate decay is handled by the LearningRateScheduler callback.
sgd = SGD(learning_rate=initial_lrate, momentum=0.9)

lr_sc = LearningRateScheduler(decay, verbose=1)

model = Xception()
model.compile(loss=['categorical_crossentropy'], optimizer=sgd, metrics=['accuracy'])

pytorch

Import the required libraries.

import torch
import torch.nn as nn
import torch.optim as optim
from torchsummary import summary
import torch.nn.functional as F

import pytorch_lightning as pl
from torchmetrics import Accuracy as accuracy

PyTorch has no built-in separable convolution layer, so we define one here.

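class SeparableConv2d(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(SeparableConv2d, self).__init__()
        # depthwise: groups=in_channels gives each input channel its own 3x3 filter
        self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1,
                                   groups=in_channels, bias=False)
        # pointwise: a 1x1 conv mixes the channels and sets the output width
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1)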

    def forward(self, x):
        out = self.depthwise(x)
        out = self.pointwise(out)
        return out

From here on, the implementation mirrors the Keras version.
Convolution unit

class conv_relu_bn(nn.Module):
    def __init__(self, in_channels, out_channels, act=True, inv=False):
        super().__init__()
        self.act = act
        self.inv = inv
        
        if act:
            self.relu = nn.ReLU(True)
        self.conv = SeparableConv2d(in_channels, out_channels)
        self.bn = nn.BatchNorm2d(out_channels)
    
    def forward(self, x):
        if self.inv:
            if self.act:
                x = self.relu(x)
            x = self.conv(x)
            x = self.bn(x)
        else:
            x = self.conv(x)
            x = self.bn(x)
            if self.act:
                x = self.relu(x)
        return x

xception_block

class xception_block(nn.Module):
    def __init__(self, in_channels, out_channels1, out_channels2=None, act=True, pool=True):
        super().__init__()
        if not out_channels2:
            out_channels2 = out_channels1
        
        self.conv1 = conv_relu_bn(in_channels=in_channels, out_channels=out_channels1, act=act)
        self.conv2 = conv_relu_bn(in_channels=out_channels1, out_channels=out_channels2)
        if pool:
            self.layer3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        else:
            self.layer3 = conv_relu_bn(in_channels=out_channels2, out_channels=out_channels2)
            
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.layer3(x)
        return x

Entry flow

class Ently_flow(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=2, padding=1)
        self.relu1 = nn.ReLU(True)
        
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU(True)
        
        self.block1 = xception_block(in_channels=64, out_channels1=128, act=False)
        self.block2 = xception_block(in_channels=128, out_channels1=256)
        self.block3 = xception_block(in_channels=256, out_channels1=728)

        self.res1 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=1, stride=2, padding=0)
        self.res2 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=1, stride=2, padding=0)
        self.res3 = nn.Conv2d(in_channels=256, out_channels=728, kernel_size=1, stride=2, padding=0)
        
    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        
        x = self.conv2(x)
        x = self.relu2(x)
        
        # Each shortcut is a strided 1x1 conv of the corresponding block's input, as in the paper.
        out1 = self.block1(x)
        out1 = out1 + self.res1(x)
        out2 = self.block2(out1)
        out2 = out2 + self.res2(out1)
        out3 = self.block3(out2)
        out3 = out3 + self.res3(out2)
        
        return out3

Middle flow

class Middle_block(nn.Module):
    def __init__(self):
        super().__init__()
        self.block = xception_block(in_channels=728, out_channels1=728, pool=False)
        
    def forward(self, x):
        res = x
        out = self.block(x)
        out = out + res
        
        return out
class Middle_flow(nn.Module):
    def __init__(self):
        super().__init__()
        self.xcept_layers = nn.Sequential(*[Middle_block() for _ in range(8)])
        
    def forward(self, x):
        for layer in self.xcept_layers:
            x = layer(x)
        
        return x

Exit flow

class Exit_flow(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.block = xception_block(in_channels=728, out_channels1=728, out_channels2=1024, pool=True)
        
        self.res = nn.Conv2d(in_channels=728, out_channels=1024, kernel_size=1, stride=2, padding=0)
        
        self.conv1 = conv_relu_bn(in_channels=1024, out_channels=1536, inv=True)
        self.conv2 = conv_relu_bn(in_channels=1536, out_channels=2048, inv=True)
        
        self.pool = nn.AdaptiveAvgPool2d((1,1))
        
        self.fc1 = nn.Linear(2048, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 10)
    
    def forward(self, x):
        res = x
        res = self.res(res)
        
        x = self.block(x)
        x = x + res
        
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.pool(x)
        x = x.view(x.shape[0], -1)
        
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)  # logits; nn.CrossEntropyLoss applies log-softmax itself
        return x

Implement the Xception model itself.

class Xception(nn.Module):
    def __init__(self):
        super().__init__()
        self.entry = Ently_flow()
        self.middle = Middle_flow()
        self.exit = Exit_flow()
        
    def forward(self, x):
        x = self.entry(x)
        x = self.middle(x)
        x = self.exit(x)
        return x

Check the structure.

from torchsummary import summary

# torchsummary defaults to device="cuda"; the model here is on the CPU
summary(Xception(), (3, 299, 299), device="cpu")

image.png

class XceptionTrainer(pl.LightningModule):
    # The *_epoch_end hooks below are PyTorch Lightning 1.x APIs (removed in Lightning 2.0).
    def __init__(self):
        super().__init__()
        self.model = Xception()
        
    def forward(self, x):
        x = self.model(x)
        return x
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        loss = nn.CrossEntropyLoss()(y_hat, y)
        # detach the logits kept for epoch-end metrics so they do not hold the autograd graph
        return {'loss': loss, 'y_hat': y_hat.detach(), 'y': y, 'batch_loss': loss.item() * x.size(0)}
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        loss = nn.CrossEntropyLoss()(y_hat, y)
        return {'y_hat': y_hat, 'y': y, 'batch_loss': loss.item() * x.size(0)}
    
    def test_step(self, batch, batch_idx):
        # return the same keys that test_epoch_end reads
        x, y = batch
        y_hat = self.forward(x)
        loss = nn.CrossEntropyLoss()(y_hat, y)
        return {'y_hat': y_hat, 'y': y, 'batch_loss': loss.item() * x.size(0)}
    
    def training_epoch_end(self, train_step_outputs):
        y_hat = torch.cat([val['y_hat'] for val in train_step_outputs], dim=0)
        y = torch.cat([val['y'] for val in train_step_outputs], dim=0)
        epoch_loss = sum([val['batch_loss'] for val in train_step_outputs]) / y_hat.size(0)
        preds = torch.argmax(y_hat, dim=1)
        acc = (preds == y).float().mean().item()  # accuracy computed directly from predictions
        self.log('train_loss', epoch_loss, prog_bar=True, on_epoch=True)
        self.log('train_acc', acc, prog_bar=True, on_epoch=True)
        
        print('---------- Current Epoch {} ----------'.format(self.current_epoch + 1))
        print('train Loss: {:.4f} train Acc: {:.4f}'.format(epoch_loss, acc))
    
    def validation_epoch_end(self, val_step_outputs):
        y_hat = torch.cat([val['y_hat'] for val in val_step_outputs], dim=0)
        y = torch.cat([val['y'] for val in val_step_outputs], dim=0)
        epoch_loss = sum([val['batch_loss'] for val in val_step_outputs]) / y_hat.size(0)
        preds = torch.argmax(y_hat, dim=1)
        acc = (preds == y).float().mean().item()
        self.log('val_loss', epoch_loss, prog_bar=True, on_epoch=True)
        self.log('val_acc', acc, prog_bar=True, on_epoch=True)
        
        print('valid Loss: {:.4f} valid Acc: {:.4f}'.format(epoch_loss, acc))
    
    # Per-epoch processing for the test data
    def test_epoch_end(self, test_step_outputs):
        y_hat = torch.cat([val['y_hat'] for val in test_step_outputs], dim=0)
        y = torch.cat([val['y'] for val in test_step_outputs], dim=0)
        epoch_loss = sum([val['batch_loss'] for val in test_step_outputs]) / y_hat.size(0)
        preds = torch.argmax(y_hat, dim=1)
        acc = (preds == y).float().mean().item()
        self.log('test_loss', epoch_loss, prog_bar=True, on_epoch=True)
        self.log('test_acc', acc, prog_bar=True, on_epoch=True)
        
        print('test Loss: {:.4f} test Acc: {:.4f}'.format(epoch_loss, acc))
        
    def configure_optimizers(self):
        # SGD with momentum 0.9; StepLR multiplies the learning rate by 0.94 every 2 epochs
        optimizer = optim.SGD(self.parameters(), lr=0.045, momentum=0.9, weight_decay=1e-4)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.94)
        return {'optimizer': optimizer, 'lr_scheduler': scheduler}
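The epoch-end hooks compute the epoch loss as sum(loss × batch size) / number of samples rather than averaging the per-batch losses. That gives the exact per-sample mean even when the last batch is smaller. A plain-Python illustration with made-up numbers (function names are hypothetical):

```python
def epoch_mean(batch_losses, batch_sizes):
    """Per-sample mean loss from per-batch mean losses, as in the *_epoch_end hooks."""
    total = sum(loss * size for loss, size in zip(batch_losses, batch_sizes))
    return total / sum(batch_sizes)


def naive_mean(batch_losses):
    """Unweighted average of per-batch means (biased when batch sizes differ)."""
    return sum(batch_losses) / len(batch_losses)


if __name__ == "__main__":
    # hypothetical values: two full batches of 128 and a final batch of 16
    losses, sizes = [2.0, 1.0, 4.0], [128, 128, 16]
    print(epoch_mean(losses, sizes), naive_mean(losses))
```

When all batches have the same size the two agree; the small final batch is what makes the unweighted average drift.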

This completes the implementation.
