Qiita Teams that are logged in
You are not logged in to any team

Log in to Qiita Team
Community
OrganizationEventAdvent CalendarQiitadon (β)
Service
Qiita JobsQiita ZineQiita Blog
7
Help us understand the problem. What are the problem?

More than 1 year has passed since last update.

【実装】あたらしい活性化関数「FReLU」をPytorch CIFER10 Turorialに組み込んでみる

FReLUとは?

FReLU (Funnel Activation)はECCV2020で発表された、画像認識に特化した活性化関数です。よく使われる活性化関数には、Sigmoid、ReLU、Swish、Mishなどがありますが、画像分類やセマンティックセグメンテーションなどの画像処理においてはこれらの上位互換という位置付けです。

原論文: "Funnel Activation for Visual Recognition", Ma, N., Zhang, X., Sun, J. (ECCV'20)
公式実装: FunnelAct (MegEngine)

FReLUを数式で表すと、

y=max(x,𝕋(x))

となります。もとのピクセル値と、depthwise畳み込みで得られたピクセル値を比較して大きいものを選択する、という仕組みです。すなわち、周囲のピクセル値が大きい部分は周囲に値に影響されて大きくなります。

image.png

FReRUの実装

FReLUの実装は、以下の通りです。

FReLU
import torch
import torch.nn as nn
class FReLU(nn.Module):
    def __init__(self, in_c, k=3, s=1, p=1):
        super().__init__()
        self.f_conv = nn.Conv2d(in_c, in_c, kernel_size=k,stride=s, padding=p,groups=in_c)
        self.bn = nn.BatchNorm2d(in_c)

    def forward(self, x):
        tx = self.bn(self.f_conv(x))
        out = torch.max(x,tx)
        return out

参考: https://qiita.com/omiita/items/bfbba775597624056987

これを、Pytorch tutorialのCIFER10(https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html)
に組み込んでいきます。CIFER10は、60000枚、32ピクセルx32ピクセル RGBの3チャンネル、airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truckの10個のクラスラベルからなる、画像分類練習用のデータセットです。
image.png
このチュートリアルでは、初期のディープニューラルネットワークである、LeNetが用いられています。元のコードがこちら。

CIFER10_Pytorch_Tutorial
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()
net = net.to(device)

「ReLU」が活性化関数として用いられていますので「FReLU」に置き換えてみましょう。FReLUは、まだtorch.nnのモジュール内に組み込まれていないため、F.reluといった使い方ができません。まず、上記のスクリプトをF.ReLUを使わない形に書き換えます。

CIFER10_Pytorch_Tutorial_2
#ReLU
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2)
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.pool(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.pool(x)
        x = x.view(-1, 16 * 5 * 5)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x


net = Net()
net = net.to(device)

__init__(self)でnn.ReLU()を定義した後に、forward(self,x)にself.relu(x)を配置していきます。
このReLUの部分を、FReLUに置き換えていきます。

CIFER10_Pytorch_Tutorial_3
#FReLU
import torch.nn as nn
import torch.nn.functional as F

class FReLU(nn.Module):
    def __init__(self, in_c, k=3, s=1, p=1):
        super().__init__()
        self.f_cond = nn.Conv2d(in_c, in_c, kernel_size=k,stride=s, padding=p,groups=in_c)
        self.bn = nn.BatchNorm2d(in_c)

    def forward(self, x):
        tx = self.bn(self.f_cond(x))
        out = torch.max(x,tx)
        return out

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2)
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.frelu1 = FReLU(6)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.frelu2 = FReLU(16)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.frelu1(x)
        x = self.pool(x)
        x = self.conv2(x)
        x = self.frelu2(x)
        x = self.pool(x)
        x = x.view(-1, 16 * 5 * 5)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x

ちなみに
net = Net()
net = net.to(device)

注意点としては、FReLUのチャンネル数は、直前のConv2dのout_channelと同一です。
あと、x.viewで画像を引き延ばした後のReLUは、置換できないですね。

Google Colabでの実装を載せておきます。

活性化関数にReLUを用いたスクリプト(バッチサイズは4)で2epoch回した結果がこちら。正解率は53%。
image.png
FReLUに置き換えた結果がこちら。Lossの減りが早い感じがします。正解率も60%にアップしました。
image.png

ちなみに、ResNetへの実装はこちら。行っていることは、上記と同じです。
https://github.com/nekitmm/FunnelAct_Pytorch
ResNetでは、ReLUをFReLUに置き換えるだけでも精度が上がることが証明されています。便利ですね。

Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
7
Help us understand the problem. What are the problem?