2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

ニューラルネットワークの予測の不確実性(Stochastic gradient Langevin dynamics・概要編)

Posted at

はじめに

Stochastic gradient Langevin dynamics (SGLD) でニューラルネットワークの予測の不確実性を算出する手法を紹介します。

Stochastic gradient Langevin dynamics

確率的勾配降下法とランジュバン動力学によるを組み合わせて、パラメータの事後分布をサンプリングする手法です。
SGLDの更新式に従ってパラメータを更新し、その過程のパラメータを事後分布からのサンプリングとして近似的に使用します。

バッチサイズを$b$、データセットのサイズを$n$、パラメータの更新量は下記になります。
$$
\Delta \theta = \frac{\alpha}{2}\Bigl(\frac{n}{b}\sum_i \nabla \log p(y_i|x_i,\theta) + \nabla \log p(\theta)\Bigr) + \sqrt{\alpha} \mathcal{N}(0,\rm{I})
$$
ここで、$\alpha$はSGLDのパラメータです。
上式は確率的勾配降下法によるパラメータの更新に正規分布に従うノイズを付加した形になっています。

パラメータ$\theta$でのデータ$x$の予測を$p(y|x,\theta)=\mathcal{N}(f(x;\theta),\sigma^2)$、正解を$y$とすると、$\log p(y|x,\theta)=-\frac{\ln 2\pi\sigma^2}{2} - \frac{(y-f(x;\theta))^2}{2\sigma^2}$となります。
つまり、この場合2乗誤差を最小化すればよいことになります。
また、$p(\theta)$に関しても平均0の正規分布を仮定すれば、パラメータのL2ノルムを最小化すればよいことになります。

DP-SGDとの関係

SGLDの更新式は、差分プライバシーを保証した確率的勾配法(DP-SGD)と等価であることが知られています。
DP-SGDの更新式は下記で定義されます。

\Delta \theta = \eta \biggl( \frac{1}{b} \sum_{i} {\rm clip}_C (\nabla \log p(y_i|x_i, \theta)) + \frac{1}{n}\nabla_{\theta} \log p(\theta) + \mathcal{N}(0,\sigma^2c^2\rm{I}) \biggr) 

ここで、${\rm clip}_C$はベクトルのL2ノルムを最大$C$に制限する関数、$\eta$は学習率、$c$と$\sigma$はDP-SGDのパラメータです。
DP-SGDのパラメータが$C=\frac{b\sqrt{2}}{\sigma \sqrt{\eta n}}$を満たす際に、SGLDと等価になります。

不確実性の算出

推論結果の事後分布は、パラメータの分布$p(\theta)$を考え、積分で$\theta$を消去することで求めます。
$$
p(y|x) = \int p(y|x, \theta)p(\theta) {\rm d}\theta
$$
この積分をSGLDでパラメータをサンプリングした結果${\theta_1, \cdots, \theta_n}$を用いて近似します。

事後分布の期待値と分散は下記の通り求めることができます。
$$
\mu(y) = \frac{1}{n}\sum_{i=1}^n f(x;\theta_i)
$$
$$
{\rm Var}(y) = \sigma^2 + \frac{1}{n}\sum_{i=1}^n f(x;\theta_i)^T f(x;\theta_i) - \mu(y)^T \mu(y)
$$

参考資料

  • M. Welling and Y. W. Teh, Bayesian learning via stochastic gradient Langevin dynamics, ICML, 2011.
  • M. Abadi, et al., Deep learning with differential privacy, ACM CCS, 2016.
  • B. Li, et al., On connecting stochastic gradient MCMC and differential privacy, AISTATS, 2019.
2
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?