#FReLUとは?
FReLU (Funnel Activation)はECCV2020で発表された、画像認識に特化した活性化関数です。よく使われる活性化関数には、Sigmoid、ReLU、Swish、Mishなどがありますが、画像分類やセマンティックセグメンテーションなどの画像処理においてはこれらの上位互換という位置付けです。
原論文: "Funnel Activation for Visual Recognition", Ma, N., Zhang, X., Sun, J. (ECCV'20)
公式実装: FunnelAct (MegEngine)
FReLUを数式で表すと、
y=max(x,𝕋(x))
となります。もとのピクセル値と、depthwise畳み込みで得られたピクセル値を比較して大きいものを選択する、という仕組みです。すなわち、周囲のピクセル値が大きい部分は周囲に値に影響されて大きくなります。
#FReRUの実装
FReLUの実装は、以下の通りです。
import torch
import torch.nn as nn
class FReLU(nn.Module):
def __init__(self, in_c, k=3, s=1, p=1):
super().__init__()
self.f_conv = nn.Conv2d(in_c, in_c, kernel_size=k,stride=s, padding=p,groups=in_c)
self.bn = nn.BatchNorm2d(in_c)
def forward(self, x):
tx = self.bn(self.f_conv(x))
out = torch.max(x,tx)
return out
参考: https://qiita.com/omiita/items/bfbba775597624056987
これを、Pytorch tutorialのCIFER10(https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html)
に組み込んでいきます。CIFER10は、60000枚、32ピクセルx32ピクセル RGBの3チャンネル、airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truckの10個のクラスラベルからなる、画像分類練習用のデータセットです。
このチュートリアルでは、初期のディープニューラルネットワークである、LeNetが用いられています。元のコードがこちら。
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
net = Net()
net = net.to(device)
「ReLU」が活性化関数として用いられていますので**「FReLU」に置き換えてみましょう**。FReLUは、まだtorch.nnのモジュール内に組み込まれていないため、F.reluといった使い方ができません。まず、上記のスクリプトをF.ReLUを使わない形に書き換えます。
#ReLU
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.relu = nn.ReLU()
self.pool = nn.MaxPool2d(2, 2)
self.conv1 = nn.Conv2d(3, 6, 5)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.conv1(x)
x = self.relu(x)
x = self.pool(x)
x = self.conv2(x)
x = self.relu(x)
x = self.pool(x)
x = x.view(-1, 16 * 5 * 5)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.relu(x)
x = self.fc3(x)
return x
net = Net()
net = net.to(device)
__init__(self)
でnn.ReLU()を定義した後に、forward(self,x)
にself.relu(x)を配置していきます。
このReLUの部分を、FReLUに置き換えていきます。
#FReLU
import torch.nn as nn
import torch.nn.functional as F
class FReLU(nn.Module):
def __init__(self, in_c, k=3, s=1, p=1):
super().__init__()
self.f_cond = nn.Conv2d(in_c, in_c, kernel_size=k,stride=s, padding=p,groups=in_c)
self.bn = nn.BatchNorm2d(in_c)
def forward(self, x):
tx = self.bn(self.f_cond(x))
out = torch.max(x,tx)
return out
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.relu = nn.ReLU()
self.pool = nn.MaxPool2d(2, 2)
self.conv1 = nn.Conv2d(3, 6, 5)
self.frelu1 = FReLU(6)
self.conv2 = nn.Conv2d(6, 16, 5)
self.frelu2 = FReLU(16)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.conv1(x)
x = self.frelu1(x)
x = self.pool(x)
x = self.conv2(x)
x = self.frelu2(x)
x = self.pool(x)
x = x.view(-1, 16 * 5 * 5)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.relu(x)
x = self.fc3(x)
return x
ちなみに、
net = Net()
net = net.to(device)
注意点としては、FReLUのチャンネル数は、直前のConv2dのout_channelと同一です。
あと、x.viewで画像を引き延ばした後のReLUは、置換できないですね。
#Google Colabでの実装を載せておきます。
https://colab.research.google.com/github/ykitaguchi77/AdvancedPytorch_Colab/blob/master/Pytorch_FReLU_CIFER10.ipynb
活性化関数にReLUを用いたスクリプト(バッチサイズは4)で2epoch回した結果がこちら。正解率は53%。
FReLUに置き換えた結果がこちら。Lossの減りが早い感じがします。正解率も60%にアップしました。
ちなみに、ResNetへの実装はこちら。行っていることは、上記と同じです。
https://github.com/nekitmm/FunnelAct_Pytorch
ResNetでは、ReLUをFReLUに置き換えるだけでも精度が上がることが証明されています。便利ですね。