AugMix
AugMixの概要
AugMixは「頑健性(robustness)」と「分類器の不確実性の推定(uncertainty estimates of image classifiers)」の向上について取り組んだデータ拡張(Data Augmentation)の手法に関する研究です。
データ拡張(Data Augmentation)手法の比較(AugMix論文 Fig.1)
上図から確認できるようにAugMixは他のAugmentationの手法に比べて自然かつ豊富な種類のAugmentationを生成することができます。
AugMix論文 Fig.3
また、多くのAugmentationの手法では上図のように連続的に変換処理を適用しますが、これによって元の画像からかけ離れた現実的でない画像が生成される場合もあります。AugMixではこのようなdegradationを軽減するにあたってMixUpのようにサンプルの合成を行います。
AugMixのアルゴリズム
AugMixの処理の概要 (AugMix論文 Fig.4)
AugMixの処理の概要は上図から確認できます。大まかに、「画像にいくつかのパターンで変換を並行で適用し、サンプルを重み付け和で合成している」と理解すれば良いです。計算の詳細は下記の疑似コード(Pseudo Code)を確認すると良いです。
AugMixの疑似コード(AugMix論文 Algorithm)
上記の疑似コードから、下記などを読み取ることができます。
・$k$個の操作(operation)をサンプリングし($op_1, op_2, op_3$)、連続する操作$op_{12} = op_2 \circ op1, , op_{123} = op_3 \circ op_2 \circ op_1$(合成関数)を作成し、$op_1, op_{12}, op_{123}$のどれかから元画像($X_{orig}$)に適用させる操作を確率的に選ぶ(l.6〜l.8)。
・変換操作を元画像に適用した結果($\mathrm{chain}(x_{orig})$)をディリクレ分布$\mathrm{Dirichlet}(\alpha, \cdots \alpha)$からサンプリングした$(w_1, \cdots , w_k)$(l.4)を用いて重み付け和$x_{aug}$を計算する(l.9)。
・ベータ分布$\mathrm{Beta}(\alpha, \alpha)$からサンプリングした$m$(l.11)を用いて元画像$x_{orig}$と重み付け和$x_{aug}$を合成する(l.12)。
・$\alpha=1$のとき(l.2)ベータ分布$\mathrm{Beta}(\alpha, \alpha)$は一様分布となるので$\alpha=1$のとき$m$は一様乱数である。
・学習のLossには画像分類のlossに$\mathcal{L}$に一致性(Consistency)を取り扱ったJSダイバージェンス(Jensen-Shannon Divergence)を加えたものを用いる(l.17)。
また、17行目に出てくるJSダイバージェンスを$JS(p_{orig}; p_{augmix1}; p_{augmix2})$のようにおくと、$JS(p_{orig}; p_{augmix1}; p_{augmix2})$は下記のように定義されます。
\begin{align}
JS(p_{orig}; p_{augmix1}; p_{augmix2}) &= \frac{1}{3} \left( KL(p_{orig}||M) + KL(p_{augmix1}||M) + KL(p_{augmix2}||M) \right) \\
M &= \frac{1}{3} (p_{orig} + p_{augmix1} + p_{augmix2}) \\
p_{orig} &= \hat{p}(y|x_{orig}) \\
p_{augmix1} &= \hat{p}(y|x_{augmix1}) \\
p_{augmix2} &= \hat{p}(y|x_{augmix2})
\end{align}
AugMax
AugMaxの概要
AugMaxのloss
まず、下記のように文字の定義を行います。
文字 | 意味 |
---|---|
$\mathbf{x} \in \mathbb{R}^{d}$ | 画像に対応するベクトル |
$\mathbf{y} \in \mathbb{R}^{c}$ | ラベルに対応するベクトル($c$はラベル数に対応) |
$f(\mathbf{x}; \theta): \mathbb{R}^{d} \rightarrow \mathbb{R}^{c}$ | classifier(predictor)、 |
このとき、AugMaxの目的関数($\mathrm{Objective}$)は下記のように定義されます。
\begin{align}
\mathrm{Objective} &= \frac{1}{2} \mathbb{E}_{(\mathbf{x}, \mathbf{y}) \sim \mathcal{D}} \left[ \mathcal{L}(f(\mathbf{x}^{*}; \theta),\mathbf{y}) + \mathcal{L}(f(\mathbf{x}; \theta),\mathbf{y}) \right] + \lambda \mathcal{L}_{c}(\mathbf{x}, \mathbf{x}^{*}) \\
\mathbf{x}^{*} &= g(\mathbf{x}; m^{*}, \mathbf{w}^{*}) \\
\mathbf{w}^{*} &= \sigma(\mathbf{p}^{*}) \\
m^{*}, \mathbf{p}^{*} &= \mathrm{argmax}_{m \in [0,1], \mathbf{p} \in \mathbb{R}^{b}} \,\, \mathcal{L}(f(g(\mathbf{x}_{orig}; m, \sigma(\mathbf{p})); \theta), \mathbf{y}) \\
\mathcal{L}_{c}(\mathbf{x}, \mathbf{x}^{*}) &= JS(f(\mathbf{x}; \theta), f(\tilde{\mathbf{x}}; \theta), f(\mathbf{x}^{*}; \theta))
\end{align}
上記は$\theta$に関する最小化と$m, \mathbf{p}$に関する最大化を同時に行うのでGAN(Generative Adversarial Network)と同様にminimaxの最適化を行います(基本的には交互にパラメータを学習させることが多いです)。