0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

Depthwise Convolution の初期化について

Last updated at Posted at 2022-05-24

はじめに

Depthwise Convolution を使った実験を以前行いました。

その中で Depthwise Convolution についていくつか気になる点があったため、まとめます。

Xavier, He の初期化に不具合が入り込みやすい

Xavier の初期化や He の初期化と呼ばれる順伝播/逆伝播時の出力の分散が入力の分散と同じになるような初期化方法(以後、正確ではありませんが He の初期化と呼ぶこととします)を利用した場合に、pytorch で確認した限り適切に初期化されないケースがありました。
原因は初期化メソッドが重み行列のみを受け取ることにあるため、他のライブラリでも発生しやすい不具合と考えています。

まず、最初にどう動作すべきかについて説明を行います。
Depthwise Convolution は、チャンネル数が1の Convolution を入力チャンネル数の分並列に並べたものといえます。
したがって、He の初期化を行う場合に、実際の入力チャンネル数/出力チャンネル数とは無関係に、チャンネル数を1と考えて初期化する必要があります。

例えば、7x7の Depthwise Convolution の場合には、順伝播/逆伝播の場合ともに標準偏差は $\frac{1}{\sqrt{1*7*7}} \fallingdotseq 0.142$ となります。

次に実際に pytorch にて初期化を行い、標準偏差の値を確認します。確認した pytorch のバージョンは 1.11.0 です。
確認には以下のスクリプトを利用しました。

import torch
channels = 128
conv = torch.nn.Conv2d(channels, channels, kernel_size=7, padding=3, groups=channels)
print("size = {}".format(conv.weight.size()))

torch.nn.init.kaiming_normal_(conv.weight, mode='fan_in', nonlinearity='linear')
print("std on fan_in = {}".format(conv.weight.std().item())) 
 
torch.nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity='linear')
print("std on fan_out = {}".format(conv.weight.std().item()))

出力は以下のとおりです。

size = torch.Size([128, 1, 7, 7])
std on fan_in = 0.1427820324897766
std on fan_out = 0.012765130959451199

fan_in を利用した場合には期待通りになっていますが、fan_out を利用した場合には期待通りになっていません。
重みのサイズを元に初期化する場合には、1入力複数出力の Convolution と Depthwise Convolution は区別できないため、インターフェイスを変えない限りは正しく初期化できません。

相対的学習率の問題

次に、正しく He の初期化を行ったとして、相対学習率の問題が発生すると考えています。

例えば、3x3 の Depthwise Convolution に対する He の初期化では重みの標準偏差が $ \frac{1}{\sqrt{1*3*3}} \fallingdotseq 0.333 $ となります。
一方、入出力のチャンネル数が 512 の Pointwise Convolution の場合は $ \frac{1}{\sqrt{512}} \fallingdotseq 0.0441 $ となります。

この場合、Depthwise Convolution の重みは Pointwise Convolution 重みの約7.5倍ほどの標準偏差となり、同じ更新幅を利用する場合、相対的に学習率が7.5分の1になります。

対応方法としては例えば ConvNeXt では一律標準偏差0.2で重みを初期化しています。
また、単純に LARS/LAMB を使うという解決方法もあります。

ランダムに初期化する必要があるのか

Depthwise Convolution では同じ入力の組み合わせからの出力は常にひとつになります。したがって、必ずしも乱数で初期化する必要はないと考えました。
そこで、PoolFormer を参考に、平均プーリングした結果から中心の要素を引くような演算を最初は行うように初期化してみます。

具体的なコードは下記のとおりです。なお、最後に重みのスケールが 0.2 になるように調整を行っています。

with torch.no_grad():
    _, _, kh, kw = dwconv.weight.size()
    dwconv.weight.data.fill_(1. / (kh * kw))
    dwconv.weight.data[:, :, kh//2, kw//2].sub_(1.0)
    dwconv.weight.data.mul_((dwconv.weight**2).mean().rsqrt() * 0.2)

以前行った Cifar-100 での実験の GELU を利用した ConvNeXt を模したネットワークで、初期化方法のみを変更して学習を行ってみました。

学習損失を以下に示します。X軸は学習時間(秒)です。

train_loss.png

検証時正答率を以下に示します。X軸は学習時間(秒)です。

val_accuracy.png

有意かと言われると微妙なところですが、若干性能が良くなりました。

おわりに

特に結論を出せるような話でもないのですが、Depthwise Convolution について考えてみました。
正直、FLOPSほどGPU/TPUで高速化されないため、あまり使われていない感はありますが、まだまだ考察の余地はありそうです。

以上

追記

この記事を投稿後に PoolFormer にて Average Pooling から入力値を引く処理が入っているのは Skip Connection の影響を排除するためという記載を見ました。
この記事の実験では Skip Connection の加算が Depthwise Convolution の直後にあるわけではないので、入力値を引く動作がないほうがよいのではないかと考え、確認を行いました。

実際の初期化コードは以下のとおりです。重みのスケールは調整しています。

if isinstance(m, DepthwiseConv2d):
    _, _, kh, kw = m.weight.size()
    m.weight.data.fill_(1. / (kh * kw))
    m.weight.data.mul_((m.weight**2).mean().rsqrt() * 0.2)

学習損失を以下に示します。X軸は epochs に変更しました。

train_loss.png

検証時正答率を以下に示します。同様にX軸は epochs です。

val_accuracy.png

Average Pooling を模した初期化では性能は悪化しました。

PoolFormer では単純な Average Pooling ではなく、以下の形になっています。

$$
\begin{equation}
y = x + \alpha \times AveragePooling(Normalization(x))
\end{equation}
$$

なお、PoolFormer の実装を見た限り $\alpha$ は初期値 1e-5 でチャンネルごとに異なる値を取れる学習可能なパラメータになっています。

この形では初期値はむしろ入力値のみになり、Normalization を除いて考えると $\alpha$ の学習により AvgPooling に近づくように変化することができる構造になっています。

この形が性能に貢献しているのかなと想像しています。

以上です。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?