2
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

フルスクラッチでバッチ正規化を理解する

Last updated at Posted at 2023-12-08

この記事でやること

E資格の勉強中にミニバッチ学習のバッチ正規化で躓きました。みんな大好き"ゼロつく"でもあまり詳しく取り上げられていないため、この記事では数式の理解とスクラッチの実装を初学者レベルで追っていきたいと思います。

バッチ正規化とは

ざっくり言うと、「学習の際にミニバッチを平均=0、標準偏差=1となるように正規化すること」です。出力が適度にばらけることで勾配消失などの問題を予防することができます。MITのDeep Learning Bookによると正確には、最適化(Optimization)ではなく再パラメータ化(Reparametrization)という手法の一つのようです。複数の隠れ層を持つ深層学習モデルで、各層の活性化層前にバッチ正規化層を挟みます)。モデル全体のイメージは下図のようになります。

バッチ正規化1.png

それではバッチ正規化の式を見ていきます。
$$ \bf{X}' = \frac{\bf{X} - \mu}{\sigma} $$
ここで$ \bf{X}' $はバッチ正規化後のパラメータ(次の層への入力)で、$ \bf{X} $はバッチ正規化前のパラメータです。$ \mu , \sigma$はそれぞれ、平均、標準偏差となります。このことにより逆伝播を考える際の勾配が、各層の平均や標準偏差に影響されにくくなり、結果として勾配の中に含まれる0要素がもたらす勾配消失問題を起きにくくすることができます。また、標準偏差$\sigma$を求める際には、分散$\sigma^2$に微小量$\delta$を足すことでアンダーフローを回避します。
$$ \begin{gather} \sigma = \sqrt{\delta + \sigma^2} \cr
\sigma^2 = \frac{1}{m}\sum_{1 \le i\le m}{(\bf{x} - \mu)_i^2} \end{gather}$$

学習と推論

ここで平均、標準偏差にどのような値を用いるかということですが、これはミニバッチ学習であるため、学習時にはデータセット全体の平均値や標準偏差を把握することができません。そこで、バッチごとの平均、標準偏差からデータ全体の平均、標準偏差を見積もることを考えます。利用するのはモーメンタム(慣性)の考え方です。各バッチの平均、標準偏差を計算した後に、すでに学習が済んでいる他のバッチの平均、標準偏差との加重平均を取ります。これによりバッチ特有の外れ値等の影響を弱めることができます。学習時にはバッチごとの平均、標準偏差を使い、推論時にはモーメンタムを導入して求めた平均と標準偏差を使います。

バッチ正規化2.png

$$ \begin{gather} \mu = m\mu_{pre} + (1 - m)\mu_{batch} \cr
\sigma^2 = m\sigma^2_{pre} + (1 - m)\sigma^2_{batch} \end{gather} $$
$m$がモーメンタムとなります。古い情報に重きを置くために$0.80\le m < 1.0$くらいが良いかと思います。また、ここでの$\mu_{pre}, \sigma^2_{pre}$はこれまでのバッチから得られている平均と分散で、$ \mu_{batch}, \sigma^2_{batch} $は学習中のバッチ内での平均と分散になります。

もう一つの工夫

Deep Learning Bookによれば、バッチ正規化の工夫はこれだけではありません。「正規化」とはデータを扱いやすくする一方で、もともと持っている情報(表現)を圧縮して削減することになります。これにより各ユニットの表現力が落ちてしまうことが懸念されます。そこで、バッチ正規化では学習パラメータとして$ \gamma, \beta$を用いて以下のように出力値を補正します。
$$ \bf{y}(output) = \gamma \bf{X}' + \beta $$
この$\bf{y}$をバッチ正規化層の出力値とします。$\gamma,\beta$は新しい変数が任意の平均や分散を取ることができるような働きがあります。これは新しいパラメータが異なる学習ダイナミクス(学習に伴うパラメータの変化具合?と解釈しました)を持つことに由来します。さらに、処理前の平均は層の中で複雑な相互作用の元にあるのに対して、新しいアウトプット$\bf{y}$での平均は単に$\beta$に従うのみなので、学習効率が良いようです。逆伝播の際に$ \gamma, \beta$の勾配計算を行うことで、他のパラメータと同様に更新することができます。

Pythonで書いてみる

import numpy as np

class BatchNorm():
    def __init__(self) :
        self.mom = 0.9  #モーメンタム
        
    def forward(self, x, train_flag=False):  # 順伝播
        """
        x : 入力値
        """
        if self.mu is None:
            N, D = x.shape
            self.mu = np.zeros(D)   #平均の初期化
            self.var = np.zeros(D)  #分散の初期化
            self.gamma = np.ones(D)
            self.beta = np.ones(D)
        else:
            pass

        if train_frag:  #訓練時のみモーメンタムを考慮した平均と分散を使う
            self.size = x.shape[0]
            mu = np.average(x, axis=0)
            self.xc = x - mu  #ミニバッチ平均とミニバッチデータの差分
            var = np.var(self.xc, axis=0)
            self.std = np.sqrt(var + 1e-6)
            self.x_n = self.xc / self.std   #学習時のX'が求まった(バッチ正規化)

            self.mu_t = self.mom * self.mu_t + (1 - self.mom) * mu    #推論で用いる平均
            self.var_t = self/mom * self.var_t + (1 - self.mom) * var  #推論で用いる分散

        else:
            xc = x - self.mu_t
            self.x_n = xc / np.sqrt(self.var_t 1e-6)  #推論時のX'が求まった(バッチ正規化)
        return self.gamma * x_n + self.beta

    def back_prop(self,delta):  # 逆伝播
        """
        delta : 逆伝播で伝わってきた勾配
        """
        self.beta_d = np.sum(delta, axis=0)
        self.gamma_d = np.sum(delta * self.x_n, axis=0)
        xn_d = self.gamma * delta
        xc_d = xn_d / self.std
        std_d = -np.sum((xn_d * self.xc) / (self.std ** 2), axis=0)
        var_d = 0.5 * std_d / self.std
        xc_d += (2.0 / self.std) * self.xc * var_d
        mu_d = np.sum(xc_d, axis=0)

        return xc_d - mu_d / self.size

逆伝播が少し煩雑ですが、計算グラフを淡々と辿っていけば理解が進むかと思います。計算グラフについては、こちら(Batch Normalizationの理解)を参考にしました。添え字の_dがついているものは勾配を表しています。
このBatchNormクラスを活性化層の前に挟むことでバッチ正規化を行うことができます。

精度の比較

MNISTを使って上記プログラムを試しました。前後の層については記述を省きますが、以下のような学習になっています。
バッチ正規化3.png

以下が学習の結果です。最適化手法にはSGDとAdamを利用しました。
image.png

image.png
バッチ正規化を行うと精度が向上したことを確認しました。また、Adamではバッチ正規化により過学習を抑止できている様子を確認できました。

参考資料

Deep Learning -An MIT Press Book- Part2: 8 Optimization for Training Deep Models
Qiita Batch Normalizationの理解

2
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?