はじめに
活性関数として $ReLU^2$ が良いらしいと聞いたので試してみました。
$ReLU^2$ は単純に ReLU を適用した後に二乗する処理になります。
斜め読みした論文は以下の通りです。
私にはスパース性がうんぬんという話はわからないので精度のみを確認しました。
実験のたたき台
以下のレポジトリのCifar10を94%識別するスクリプトをたたき台としました。
実行回数を200から10に減らして実行しました。
正答率は平均 94.09% でした。実行時間は私の使った環境では8.5~8.8秒程度でした。
単純に $ReLU^2$ を使う
次に単純に活性関数を $ReLU^2$ に置き換えて実行してみました。元のモデルの活性関数は GELU です。
コードは以下になります。
正答率はチャンスレベルの10%程度になり、学習に失敗しました。
冷静に考えると2乗しているので、入力が1以下の場合勾配が小さくなり、1以上の場合大きくなります。つまり、積層すると勾配消失と勾配爆発のどちらも発生する可能性が高いです。
スキップ接続の導入
たたき台にしたスクリプトでは単純に畳み込み層、正規化層、活性関数を積層した構造になっています。
色々試行錯誤したところ、スキップ接続を導入した上で、残差計算を行うモジュールの中で1回だけ $ReLU^2$ を挿入すると学習に成功するようになりました。
そのコードがこちらです。
正答率は92.64%と若干悪くなりました。
スキップ接続の影響と $ReLU^2$ の影響を切り分けるため、スキップ接続を導入しつつ、活性関数をGELUのままのものも実行しました。その場合の正答率は 92.93% となり、$ReLU^2$ の場合よりも良い結果になりました。
学習が成功するまでに試した結果を以下にまとめます。
手法 | 正答率 | 実行時間 |
---|---|---|
変更前 | 94.09% | 8.5~8.8秒 |
$ReLU^2$ の利用 | 10% | 8.0~8.1秒 |
スキップ接続の導入 | 92.93% | 9.7~9.9秒 |
スキップ接続の導入 + $ReLU^2$の利用 | 10% | 9.1~9.2秒 |
スキップ接続の導入 + 一つおきに $ReLU^2$を利用 | 92.64% | 9.5~9.7秒 |
地味に GELU よりも高速なようですが、正答率は低下しています。
Concat ReLU を思い出す
活性関数の ReLU の亜種として Concat ReLU というものがあります。
こちらは単純に x と -x を concat した上で ReLU を適用するというものです。pytorchのコードで書くといかのようになります。
import torch
import torch.nn.functional as F
def concat_relu(x, dim=-1):
return F.relu(torch.concat([x, -x], dim=dim)
この処理によりチャンネル数は倍に増加するため、直前の計算での出力を半分にすることが可能で、計算量の削減が期待できます。GELUなど若干負の成分を利用する活性関数では使用していいものか悩みますが、$ReLU^2$の場合、負の値は ReLU 同様まったく利用しないため、気持ちよく利用できます。
$ReLU^2$ に符号を逆にしたものを concat する処理を加えたものを$concat ReLU^2$と呼ぶこととします。
pytorch で実装した場合は以下のようになります。
import torch
import torch.nn.functional as F
def concat_relu2(x, dim=-1):
return F.relu(torch.concat([x, -x], dim=dim)**2
「スキップ接続の導入 + 一つおきに $ReLU^2$ を利用」の $ReLU^2$ を $concat ReLU^2$に置き換えて実行して見ました。直前の層の出力は半分にしています。
スクリプトは以下になります。
結果は以下の通りです。
手法 | 正答率 | 実行時間 |
---|---|---|
スキップ接続の導入 + 一つおきに $ReLU^2$を利用 | 92.64% | 9.5~9.7秒 |
スキップ接続の導入 + 一つおきに $Concat ReLU^2$を利用 | 91.72% | 7.5~7.7秒 |
正答率はかなり低下していますが、実行時間はかなり短縮できています。これが割に合う変更なのかはなんとも言えません。
負の値を利用してみる
どうせ2乗してるのだから、別に負の値を切り捨てなくとも非線形性は保持できるなと考えて次のような活性関数を考えてみました。ここでは sign square と呼ぶこととします。
import torch
import torch.nn.functional as F
def sign_square(x):
return torch.sign(x) * torch.square(x)
$ReLU^2$をsign square に置き換えて実行した結果を以下に示します。
手法 | 正答率 | 実行時間 |
---|---|---|
スキップ接続の導入 + 一つおきに $ReLU^2$を利用 | 92.64% | 9.5~9.7秒 |
スキップ接続の導入 + 一つおきに sign squareを利用 | 92.46% | 9.7~10.3秒 |
若干正答率が下がった上に実行時間も大きくなりました。
おわりに
$ReLU^2$について試してみました。
実験の結果、原則正の活性値については勾配の大きさを変更しない ReLU 系の活性関数に比べて、適用できる場所が限られているピーキーな活性関数であることがわかりました。ただし、昨今、使われているモデルでは活性関数が使われるのは、FFNの残差ブロック中に一回であるため、$ReLU^2$は適用可能なモデルがほとんどと考えています。
また、正答率については今回の実験ではGELUに比べて若干性能が低下するという結果になりました。
計算量を削減できる $concat ReLU^2$ については結構可能性があるんじゃないかなと個人的には思っています。
以上になります。