1. Radley

    Posted

    Radley
Changes in title
+【実装】あたらしい活性化関数「FReLU」をPytorch CIFER10 Turorialに組み込んでみる
Changes in tags
Changes in body
Source | HTML | Preview
@@ -0,0 +1,172 @@
+#FReLUとは?
+
+FReLU (Funnel Activation)はECCV2020で発表された、**画像認識に特化した活性化関数**です。よく使われる活性化関数には、Sigmoid、ReLU、Swish、Mishなどがありますが、画像分類やセマンティックセグメンテーションなどの画像処理においてはこれらの上位互換という位置付けです。
+
+>原論文: "Funnel Activation for Visual Recognition", Ma, N., Zhang, X., Sun, J. (ECCV'20)
+>公式実装: FunnelAct (MegEngine)
+
+FReLUを数式で表すと、
+
+```math
+y=max(x,𝕋(x))
+```
+
+となります。もとのピクセル値と、depthwise畳み込みで得られたピクセル値を比較して大きいものを選択する、という仕組みです。すなわち、周囲のピクセル値が大きい部分は周囲に値に影響されて大きくなります。
+
+![image.png](https://qiita-image-store.s3.ap-northeast-1.amazonaws.com/0/620139/24d21bb7-0299-41bb-9b57-49c722f7b00c.png)
+
+#FReRUの実装
+FReLUの実装は、以下の通りです。
+
+```python:FReLU
+import torch
+import torch.nn as nn
+class FReLU(nn.Module):
+ def __init__(self, in_c, k=3, s=1, p=1):
+ super().__init__()
+ self.f_conv = nn.Conv2d(in_c, in_c, kernel_size=k,stride=s, padding=p,groups=in_c)
+ self.bn = nn.BatchNorm2d(in_c)
+
+ def forward(self, x):
+ tx = self.bn(self.f_conv(x))
+ out = torch.max(x,tx)
+ return out
+```
+参考: https://qiita.com/omiita/items/bfbba775597624056987
+
+これを、Pytorch tutorialのCIFER10(https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html)
+に組み込んでいきます。CIFER10は、60000枚、32ピクセルx32ピクセル RGBの3チャンネル、airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truckの10個のクラスラベルからなる、画像分類練習用のデータセットです。
+![image.png](https://qiita-image-store.s3.ap-northeast-1.amazonaws.com/0/620139/2e7e5b3b-b8a4-fcea-9fba-baafdd444969.png)
+このチュートリアルでは、初期のディープニューラルネットワークである、LeNetが用いられています。元のコードがこちら。
+
+```python:CIFER10_Pytorch_Tutorial
+import torch.nn as nn
+import torch.nn.functional as F
+
+class Net(nn.Module):
+ def __init__(self):
+ super(Net, self).__init__()
+ self.conv1 = nn.Conv2d(3, 6, 5)
+ self.pool = nn.MaxPool2d(2, 2)
+ self.conv2 = nn.Conv2d(6, 16, 5)
+ self.fc1 = nn.Linear(16 * 5 * 5, 120)
+ self.fc2 = nn.Linear(120, 84)
+ self.fc3 = nn.Linear(84, 10)
+
+ def forward(self, x):
+ x = self.pool(F.relu(self.conv1(x)))
+ x = self.pool(F.relu(self.conv2(x)))
+ x = x.view(-1, 16 * 5 * 5)
+ x = F.relu(self.fc1(x))
+ x = F.relu(self.fc2(x))
+ x = self.fc3(x)
+ return x
+
+net = Net()
+net = net.to(device)
+```
+「ReLU」が活性化関数として用いられていますので**「FReLU」に置き換えてみましょう**。FReLUは、まだtorch.nnのモジュール内に組み込まれていないため、F.reluといった使い方ができません。まず、上記のスクリプトをF.ReLUを使わない形に書き換えます。
+
+```python:CIFER10_Pytorch_Tutorial_2
+#ReLU
+import torch.nn as nn
+import torch.nn.functional as F
+
+class Net(nn.Module):
+ def __init__(self):
+ super(Net, self).__init__()
+ self.relu = nn.ReLU()
+ self.pool = nn.MaxPool2d(2, 2)
+ self.conv1 = nn.Conv2d(3, 6, 5)
+ self.conv2 = nn.Conv2d(6, 16, 5)
+ self.fc1 = nn.Linear(16 * 5 * 5, 120)
+ self.fc2 = nn.Linear(120, 84)
+ self.fc3 = nn.Linear(84, 10)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.relu(x)
+ x = self.pool(x)
+ x = self.conv2(x)
+ x = self.relu(x)
+ x = self.pool(x)
+ x = x.view(-1, 16 * 5 * 5)
+ x = self.fc1(x)
+ x = self.relu(x)
+ x = self.fc2(x)
+ x = self.relu(x)
+ x = self.fc3(x)
+ return x
+
+
+net = Net()
+net = net.to(device)
+```
+`__init__(self)`でnn.ReLU()を定義した後に、`forward(self,x)`にself.relu(x)を配置していきます。
+このReLUの部分を、FReLUに置き換えていきます。
+
+```python:CIFER10_Pytorch_Tutorial_3
+#FReLU
+import torch.nn as nn
+import torch.nn.functional as F
+
+class FReLU(nn.Module):
+ def __init__(self, in_c, k=3, s=1, p=1):
+ super().__init__()
+ self.f_cond = nn.Conv2d(in_c, in_c, kernel_size=k,stride=s, padding=p,groups=in_c)
+ self.bn = nn.BatchNorm2d(in_c)
+
+ def forward(self, x):
+ tx = self.bn(self.f_cond(x))
+ out = torch.max(x,tx)
+ return out
+
+class Net(nn.Module):
+ def __init__(self):
+ super(Net, self).__init__()
+ self.relu = nn.ReLU()
+ self.pool = nn.MaxPool2d(2, 2)
+ self.conv1 = nn.Conv2d(3, 6, 5)
+ self.frelu1 = FReLU(6)
+ self.conv2 = nn.Conv2d(6, 16, 5)
+ self.frelu2 = FReLU(16)
+ self.fc1 = nn.Linear(16 * 5 * 5, 120)
+ self.fc2 = nn.Linear(120, 84)
+ self.fc3 = nn.Linear(84, 10)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.frelu1(x)
+ x = self.pool(x)
+ x = self.conv2(x)
+ x = self.frelu2(x)
+ x = self.pool(x)
+ x = x.view(-1, 16 * 5 * 5)
+ x = self.fc1(x)
+ x = self.relu(x)
+ x = self.fc2(x)
+ x = self.relu(x)
+ x = self.fc3(x)
+ return x
+
+ちなみに、
+net = Net()
+net = net.to(device)
+```
+
+注意点としては、**FReLUのチャンネル数は、直前のConv2dのout_channelと同一です。**
+あと、x.viewで画像を引き延ばした後のReLUは、置換できないですね。
+
+#Google Colabでの実装を載せておきます。
+https://colab.research.google.com/github/ykitaguchi77/AdvancedPytorch_Colab/blob/master/Pytorch_FReLU_CIFER10.ipynb
+
+活性化関数にReLUを用いたスクリプト(バッチサイズは4)で2epoch回した結果がこちら。正解率は53%。
+![image.png](https://qiita-image-store.s3.ap-northeast-1.amazonaws.com/0/620139/fa165fcd-4f29-0d03-5628-e5129a5e282b.png)
+FReLUに置き換えた結果がこちら。Lossの減りが早い感じがします。正解率も60%にアップしました。
+![image.png](https://qiita-image-store.s3.ap-northeast-1.amazonaws.com/0/620139/e376382b-0986-4316-724a-77a7a3783d39.png)
+
+
+
+ちなみに、**ResNetへの実装**はこちら。行っていることは、上記と同じです。
+https://github.com/nekitmm/FunnelAct_Pytorch
+ResNetでは、ReLUをFReLUに置き換えるだけでも精度が上がることが証明されています。便利ですね。