0
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

モンテカルロ獲得関数を理解したい

0
Posted at

はじめに

ベイズ最適化では、次にどの点を評価するかを獲得関数が決めます。通常の獲得関数は1回のサイクルで1点を提案しますが、複数の候補点を同時に提案したい場合(バッチ観測)には、モンテカルロ(MC)獲得関数が使われます。

BoTorchを使っていてqExpectedImprovementoptimize_acqfといった関数を目にしたことがあったのですが、内部でどのような計算が行われているのかがよくわかっていませんでした。そこで今回、MC獲得関数の仕組みを一通り調べてまとめました。

具体的には以下の点を整理しています。

  • なぜ獲得関数をモンテカルロで近似するのか
  • 候補点 $\mathbf{x}$ はどのように初期化・更新されるのか
  • 再パラメータ化によってなぜ勾配が計算できるのか
  • RQMCとSAAはそれぞれ何のためにあるのか

前提知識:ガウス過程(GP)によるサロゲートモデル、Expected Improvementなどベイズ最適化の基本概念は既知とします。BoTorchの使用経験は問いません。

この記事は、獲得関数評価を$q$回同時に行う合同(joint)最適化を前提に書いていますが、実際には$q$回評価を繰り返す逐次(greedy)最適化でも使用されます。BoTorchでのモンテカルロ関数を使った逐次最適化に関しては、最後のほうに記載します。
こちらの方が良い結果が得られる場合もあるようです。(論文

【参考】

自分には難しかったですが、ベイズ最適化関連に詳しい人は読めるのかもしれません…

セクション2:4つの関数の関係

MC獲得関数の計算には、4つの関数が登場します。それぞれの役割を整理してから、データの流れを確認します。

目的関数

最大化したい本物の関数です。実験や高コストなシミュレーションがこれにあたります。評価コストが高いため、直接最適化することができません。

$$f_{\text{true}}(\mathbf{x})$$

サロゲートモデル

観測データ $\mathcal{D}$ をもとに $f_{\text{true}}$ を確率的に近似するモデルです。BoTorchでは通常ガウス過程(GP)が使われます。候補点 $\mathbf{x}$ を与えると、その点での関数値の事後分布を返します。

$$f_{\mathcal{D}} = p(f_{\text{true}} \mid \mathcal{D})$$

目的関数レイヤー

サロゲートモデルの出力を「最大化したい方向」に揃えるための変換レイヤーです。例えばサロゲートの出力が「小さいほど良い」場合、符号を反転させます。何も変換が不要な場合は恒等変換になります。BoTorchではIdentityMCObjectiveなどがこれにあたります。

$$g(\cdot)$$

このレイヤーはサロゲートとutility関数の間に位置する独立した層であり、省略可能です。

utility関数

ここが理解のポイントになる部分です。順を追って説明します。

$$a(\cdot; \Phi)$$

まず、$q$個の候補点を以下のようにまとめて表記します。

$$\mathbf{x} = \{\mathbf{x}_1, \mathbf{x}_2, \ldots, \mathbf{x}_q\}$$

サロゲートモデル $f_{\mathcal{D}}$ にこの$q$個の候補点をまとめて渡すと、各候補点での関数値の事後分布が得られます。この事後分布は、$q$個の点をまとめた多変量ガウス分布になります。候補点を独立に扱うのではなく、まとめて一つの分布として扱うことで、候補点間の相関が保持される点が重要です。

この多変量ガウス分布から1回サンプリングすると、$q$個の値のセットが得られます。

$$\boldsymbol{\xi}^{(n)} = (\xi_1^{(n)}, \xi_2^{(n)}, \ldots, \xi_q^{(n)}), \quad \xi_i^{(n)} \sim f_{\mathcal{D}}(\mathbf{x}_i)$$

つまり $\boldsymbol{\xi}^{(n)}$ は「$q$個の候補点 $\mathbf{x}_1, \ldots, \mathbf{x}_q$ それぞれに対して、サロゲートモデルからサンプリングした値をまとめたベクトル」です。サンプル一本が、$q$個の候補点に対応する$q$次元のベクトルになります。

utility関数 $a(\cdot; \Phi)$ はこの $g(\boldsymbol{\xi}^{(n)})$ を受け取り、「どれだけ良いか」を表すスカラーに変換します。$\Phi$ はutility関数のパラメータであり、qEIの場合は現在の最良観測値 $f^*$ がこれにあたります。qEIのutility関数はReLUを使って以下のように書けます。

$$a(\boldsymbol{\xi}; \Phi) = \max_{i=1,\ldots,q}\text{ReLU}(\xi_i - f^*)$$

$$\text{ReLU}(z) = \max(z, 0)$$

ここで $f^*$ は現在の最良観測値です。q個の候補点の中で、現在の最良値を超えた改善量が最も大きいものをスカラーとして返します。

獲得関数

utility関数の期待値として定義されます。

\alpha(\mathbf{x}) = \mathbb{E}_{f_{\mathcal{D}}}\left[a\left(g(\boldsymbol{\xi}); \Phi\right)\right]

この値を最大化する $\mathbf{x}$ が、次の観測候補点として提案されます。BoTorchではこの期待値をモンテカルロ近似で計算します。

$$\alpha(\mathbf{x}) \approx \frac{1}{N}\sum_{n=1}^{N} a\left(g(\boldsymbol{\xi}^{(n)}); \Phi\right)$$

ここで $\boldsymbol{\xi}^{(n)}$ は事後分布から引いた$n$番目のサンプルです。$N$ 個のサンプルを平均することで期待値を近似します。

なぜMC近似が必要なのかは次のセクションで説明します。

セクション3:なぜMCで計算するのか

セクション2で、獲得関数は「utilityの期待値」として定義されると説明しました。この期待値は本来、事後分布に関する積分として表されます。

$$\alpha(\mathbf{x}) = \int a\left(g(\boldsymbol{\xi}); \Phi\right) p(\boldsymbol{\xi} \mid \mathcal{D}) d\boldsymbol{\xi}$$

この積分を解析的に解ければ閉じた式が得られますが、BoTorchではあえてMC近似を使っています。その理由を説明します。

q=1のときは解析的に解けた

$q=1$、つまり候補点が1点の場合を考えます。EI(Expected Improvement)のutility関数は以下です。

$$a(\xi; \Phi) = \text{ReLU}(\xi - f^*)$$

GPの事後分布はガウス分布 $\mathcal{N}(\mu, \sigma^2)$ なので、この期待値は以下の積分になります。

$$\alpha(x) = \int \text{ReLU}(\xi - f^*) \cdot \mathcal{N}(\xi; \mu, \sigma^2) d\xi$$

ReLUとガウス分布の積の積分は解析的に解くことができ、以下の閉じた式が得られます。

$$\alpha(x) = \sigma \left[\phi(z) + z\Phi(z)\right], \quad z = \frac{\mu - f^*}{\sigma}$$

ここで $\phi$ は標準正規分布の確率密度関数、$\Phi$ は累積分布関数です。この式があるため、古典的なEIは高速に計算できました。

q>1で2つの壁が現れる

候補点をq個に増やすと、2つの理由から解析的な計算が困難になります。

壁①:$\max$ の積分が解けない

qEIのutility関数は以下です。

$$a(\boldsymbol{\xi}; \Phi) = \max_{i=1,\ldots,q}\text{ReLU}(\xi_i - f^*)$$

この期待値は多変量ガウス分布に対する $\max$ の積分になります。

$$\alpha(\mathbf{x}) = \int \max_{i}\text{ReLU}(\xi_i - f^*) \cdot \mathcal{N}(\boldsymbol{\xi}; \boldsymbol{\mu}, \boldsymbol{\Sigma}) d\boldsymbol{\xi}$$

$\max$ は非線形な操作であり、多変量ガウス分布との積分を解析的に解く閉じた式は存在しません。

壁②:utility関数の非線形性

$\max$ や $\text{ReLU}$ といった非線形な変換が入ると、一般に積分が閉じた式になりません。q=1のEIが解析的に解けたのは、ReLUとガウス分布の組み合わせが偶然うまく積分できたからです。utility関数の設計が変わるたびに解析解を導出し直す必要があり、汎用性がありません。

MCで近似することの利点

MC近似を使えば、utility関数の形やqの値によらず、同じ手順で獲得関数を計算できます。

$$\alpha(\mathbf{x}) \approx \frac{1}{N}\sum_{n=1}^{N} a\left(g(\boldsymbol{\xi}^{(n)}); \Phi\right)$$

事後分布からサンプルを引いて平均するだけなので、utility関数がどんな非線形な形をしていても対応できます。どんなutility関数でも・どんなqでも同じ枠組みで計算できる汎用性が、BoTorchの良いところです。

BoTorchではこの $N$ がnum_samplesに対応します。qExpectedImprovementなどのMC獲得関数を初期化する際に指定するパラメータで、おおよそ512〜1024程度の値が使われることが多いです。$N$ が大きいほど期待値の推定精度が上がりますが、計算コストもその分増えます。

from botorch.acquisition import qExpectedImprovement
from botorch.sampling.normal import SobolQMCNormalSampler

sampler = SobolQMCNormalSampler(sample_shape=torch.Size([1024]))  # N=1024

qEI = qExpectedImprovement(
    model=surrogate,
    best_f=best_f,
    sampler=sampler,
)

セクション4:計算の全体フロー

ここまでで登場人物の整理と、MCを使う理由が明確になりました。セクション4では、獲得関数の最適化がどのような手順で進むかを俯瞰します。各ステップの詳細は後続のセクションで説明します。

BoTorchが「次に試すべき候補点 $\mathbf{x}$」を決めるまでの流れは以下のようになっています。

  1. Sobolシーケンスで$q$個の候補点$\mathbf{x}$を初期化
  2. サロゲートモデルに$\mathbf{x}$を渡し、事後分布を得る
  3. RQMCで$\boldsymbol{\varepsilon}$をサンプリング → 再パラメータ化で$\boldsymbol{\xi}^{(1)}, \ldots, \boldsymbol{\xi}^{(N)}$を生成 → $\boldsymbol{\varepsilon}$はSAAで固定・使いまわす
  4. $g(\boldsymbol{\xi}) \to a(g(\boldsymbol{\xi});\Phi)$ でサンプルごとにスカラー化
  5. $N$個の平均 → 獲得関数の値$\alpha(\mathbf{x})$を得る
  6. 連鎖率で$\frac{\partial a}{\partial g} \cdot \frac{\partial g}{\partial h} \cdot \frac{\partial h}{\partial \mathbf{x}}$を計算
  7. L-BFGS-Bで$\mathbf{x}$を更新 → ②に戻る
  8. 収束した$\mathbf{x}$を次の観測点の候補として返す

いくつか見慣れない言葉が登場していますが、それぞれ以下のセクションで説明します。

ステップ 関連セクション
① Sobolによる初期化 セクション5
②③ 事後分布からのサンプリング セクション6
③ 再パラメータ化とε セクション7
③ RQMCとSAA セクション8
⑥ 連鎖率による勾配計算 セクション7
⑦ L-BFGS-Bによる更新とnum_restarts セクション9

一点補足します。このフローは「一つの初期点から出発した場合」の流れです。実際にはnum_restarts個の初期点から並列に②〜⑧を実行し、最終的に最も $\alpha(\mathbf{x})$ が高かった $\mathbf{x}$ を採用します。この点はセクション9で改めて整理します。

セクション5:Sobolによる初期化

勾配降下法で獲得関数を最大化するには、まず候補点 $\mathbf{x}$ の初期値を決める必要があります。BoTorchではこの初期化にSobolシーケンスを使っています。

なぜ一様乱数ではなくSobolなのか

単純な一様乱数で初期点を生成すると、偶然に点が偏って配置されることがあります。Sobolシーケンスは低discrepancy数列と呼ばれる準乱数の一種で、設計上、空間をできるだけ均一にカバーするように点が配置されます。

直感的なイメージとしては、一様乱数が「ランダムに散らばる」のに対して、Sobolシーケンスは「なるべく均等に埋まるように配置される」と考えると良いです。初期点の質が高いほど、勾配降下法が良い出発点からスタートできるため、最終的に得られる獲得関数の最大値の品質が上がります。

raw_samplesnum_restartsnum_samplesの関係

BoTorchで獲得関数を最適化する際、似た名前のパラメータが複数登場します。それぞれの役割は以下のように異なります。

パラメータ 役割
raw_samples Sobolシーケンスで生成する初期候補点の総数。おおよそ512〜2048程度
num_restarts 勾配降下法の出発点として使う候補点の数。おおよそ10〜20程度
num_samples MC近似のサンプル数N(セクション3参照)。おおよそ512〜1024程度

raw_samplesnum_restartsの関係は以下のような2段階の戦略になっています。

  1. Sobolシーケンスでraw_samples個の初期候補点を広く生成する
  2. それぞれの獲得関数値 $\alpha(\mathbf{x})$ を評価し、値が高い上位num_restarts個を選ぶ
  3. 選ばれたnum_restarts個を出発点として、それぞれ独立に勾配降下法を実行する

広く探索してから有望な点に絞り込む、という2段階の戦略です。num_samplesはこの初期化とは独立した概念であり、MC近似の精度に関わるパラメータです。

サンプルコード

optimize_acqfの引数として3つのパラメータがどこに現れるかを示します。

from botorch.optim import optimize_acqf
from botorch.acquisition import qExpectedImprovement
from botorch.sampling.normal import SobolQMCNormalSampler

# num_samplesはqEIの初期化時に指定(セクション3参照)
sampler = SobolQMCNormalSampler(sample_shape=torch.Size([1024]))
qEI = qExpectedImprovement(
    model=surrogate,
    best_f=best_f,
    sampler=sampler,
)

# raw_samplesとnum_restartsはoptimize_acqfで指定
candidate, acq_value = optimize_acqf(
    acq_function=qEI,
    bounds=bounds,
    q=2,
    num_restarts=10,   # 勾配降下法の出発点の数
    raw_samples=512,   # Sobolで生成する初期候補点の総数
)

セクション6:MC獲得関数の計算

セクション2で獲得関数の定義を説明しました。このセクションでは、サロゲートモデルの事後分布からサンプリングして獲得関数値を計算するまでの流れを、もう少し丁寧に説明します。

事後分布の構造

q個の候補点
$$\mathbf{x} = \{\mathbf{x}_1, \ldots, \mathbf{x}_q\}$$

をサロゲートモデル$f_{\mathcal{D}}$に渡すと、各候補点での関数値の事後分布が得られます。GPの事後分布はq次元の多変量ガウス分布になります。

$$p(\boldsymbol{\xi} \mid \mathcal{D}, \mathbf{x}) = \mathcal{N}(\boldsymbol{\mu}(\mathbf{x}), \boldsymbol{\Sigma}(\mathbf{x}))$$

ここで $\boldsymbol{\mu}(\mathbf{x}) \in \mathbb{R}^q$ は各候補点での事後平均、$\boldsymbol{\Sigma}(\mathbf{x}) \in \mathbb{R}^{q \times q}$ は候補点間の共分散行列です。

重要なのは、q個の候補点を独立にサンプリングするのではなく、まとめて一つの多変量ガウス分布からサンプリングする点です。共分散行列 $\boldsymbol{\Sigma}(\mathbf{x})$ が候補点間の相関を保持しているため、例えば「$\mathbf{x}_1$ と $\mathbf{x}_2$ が近い点であれば、サンプル値も似た傾向になる」という情報が反映されます。これを独立にサンプリングしてしまうと、候補点間の相関が失われ、qEIの計算が不正確になります。

サンプリングから獲得関数値までの流れ

この多変量ガウス分布からN回サンプリングすることで、N個のサンプル $\boldsymbol{\xi}^{(1)}, \ldots, \boldsymbol{\xi}^{(N)}$ が得られます。各サンプル $\boldsymbol{\xi}^{(n)}$ は「q個の候補点それぞれに対するサロゲートモデルの予測値を一本の軌跡としてまとめたもの」です。

$$\boldsymbol{\xi}^{(n)} = (\xi_1^{(n)}, \xi_2^{(n)}, \ldots, \xi_q^{(n)}), \quad n = 1, \ldots, N$$

各サンプルに対して $g$ と $a(\cdot; \Phi)$ を適用し、スカラー値を得ます。

$$s^{(n)} = a\left(g(\boldsymbol{\xi}^{(n)}); \Phi\right) \in \mathbb{R}$$

N個のスカラー値を平均したものが獲得関数の近似値です。

$$\alpha(\mathbf{x}) \approx \frac{1}{N}\sum_{n=1}^{N} s^{(n)}$$

サンプル軌跡のイメージ

下の図はBoTorchの論文から引用した、事後分布からN本のサンプル軌跡を引くイメージです。各軌跡はq個の候補点での値をまとめたものであり、utility関数はその軌跡一本を受け取って「どれだけ良いか」をスカラーとして返します。N本の軌跡に対してutilityを計算し、その平均が獲得関数値になります。

サンプルの流れ.png

Nを増やすとどうなるか

$N$(num_samples)を大きくするほど、MC近似の精度が上がり $\alpha(\mathbf{x})$ の推定が安定します。一方で計算コストも $N$ に比例して増えます。512〜1024程度が実用的なバランスとして使われることが多いです。

ただし $N$ を増やして推定を安定させるだけでは不十分です。サンプリングにランダム性がある以上、$\mathbf{x}$ を少し動かしたときに $\alpha(\mathbf{x})$ の値がガタつき、勾配降下法が不安定になる問題が残ります。この問題への対処がセクション7・8のテーマです。

セクション7:再パラメータ化と連鎖率

セクション6の末尾で、「サンプリングにランダム性がある以上、勾配降下法が不安定になる」という問題を指摘しました。このセクションではその問題を解決する再パラメータ化トリックと、それによって勾配が計算できるようになる仕組みを説明します。

問題:サンプリング操作はxに対して微分不可能

勾配降下法で $\mathbf{x}$ を更新するには、$\frac{\partial \alpha}{\partial \mathbf{x}}$ を計算する必要があります。しかし獲得関数の計算過程には以下のようなサンプリング操作が含まれています。

$$\boldsymbol{\xi}^{(n)} \sim \mathcal{N}(\boldsymbol{\mu}(\mathbf{x}), \boldsymbol{\Sigma}(\mathbf{x}))$$

この操作は $\mathbf{x}$ に対して微分不可能です。$\mathbf{x}$ を少し動かしたとき、サンプル $\boldsymbol{\xi}^{(n)}$ がどう変化するかを「微分」として表現できないため、$\alpha(\mathbf{x})$ の勾配を $\mathbf{x}$ まで伝播させることができません。

解決策:再パラメータ化トリック

ランダム性を $\mathbf{x}$ に依存しない変数 $\boldsymbol{\varepsilon}$ に分離します。

$$\boldsymbol{\xi}^{(n)} = h(\mathbf{x}, \boldsymbol{\varepsilon}^{(n)}) = \boldsymbol{\mu}(\mathbf{x}) + \mathbf{L}(\mathbf{x}),\boldsymbol{\varepsilon}^{(n)}, \quad \boldsymbol{\varepsilon}^{(n)} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$$

ここで $\mathbf{L}(\mathbf{x})$ は共分散行列 $\boldsymbol{\Sigma}(\mathbf{x})$ のコレスキー分解、すなわち $\boldsymbol{\Sigma}(\mathbf{x}) = \mathbf{L}(\mathbf{x})\mathbf{L}(\mathbf{x})^\top$ を満たす下三角行列です。

この変換のポイントは、ランダム性がすべて $\boldsymbol{\varepsilon}^{(n)}$ に押し込まれていることです。$\boldsymbol{\varepsilon}^{(n)}$ を固定してしまえば、$\boldsymbol{\xi}^{(n)}$ は $\mathbf{x}$ の決定論的な関数になります。

\boldsymbol{\xi}^{(n)} = \underbrace{\boldsymbol{\mu}(\mathbf{x})}_{\mathbf{x}\text{の関数}} + \underbrace{\mathbf{L}(\mathbf{x})}_{\mathbf{x}\text{の関数}}\underbrace{\boldsymbol{\varepsilon}^{(n)}}_{\text{固定}}

これにより、$\mathbf{x}$ に関する微分が計算できるようになります。

連鎖率で勾配を伝播させる

$\boldsymbol{\varepsilon}^{(n)}$ を固定した状態で、$\alpha(\mathbf{x})$ の $\mathbf{x}$ に関する勾配を連鎖率で展開します。

計算のパイプラインを整理すると、以下の3段階になっています。

$$\mathbf{x} \xrightarrow{h} \boldsymbol{\xi}^{(n)} \xrightarrow{g} g(\boldsymbol{\xi}^{(n)}) \xrightarrow{a(\cdot;\Phi)} s^{(n)}$$

  • $h(\mathbf{x}, \boldsymbol{\varepsilon}^{(n)}) = \boldsymbol{\mu}(\mathbf{x}) + \mathbf{L}(\mathbf{x})\boldsymbol{\varepsilon}^{(n)}$:再パラメータ化トリックの式
  • $g(\boldsymbol{\xi}^{(n)})$:目的関数レイヤーによる変換
  • $a(g(\boldsymbol{\xi}^{(n)}); \Phi)$:utility関数によるスカラー化

連鎖率を適用すると以下になります。

\frac{\partial \alpha}{\partial \mathbf{x}} = \frac{1}{N}\sum_{n=1}^{N} \underbrace{\frac{\partial a}{\partial g(\boldsymbol{\xi}^{(n)})}}_{\text{utilityの感度}} \cdot \underbrace{\frac{\partial g}{\partial \boldsymbol{\xi}^{(n)}}}_{\text{目的関数の感度}} \cdot \underbrace{\frac{\partial h}{\partial \mathbf{x}}}_{\text{xが事後分布に与える影響}}

各項の意味は以下の通りです。

意味
$\frac{\partial a}{\partial g(\boldsymbol{\xi}^{(n)})}$ $g(\boldsymbol{\xi}^{(n)})$ のどの成分を動かすと獲得関数値が上がるか
$\frac{\partial g}{\partial \boldsymbol{\xi}^{(n)}}$ サロゲートの出力が変わったとき、目的関数レイヤーの出力がどう変わるか
$\frac{\partial h}{\partial \mathbf{x}}$ $\mathbf{x}$ を動かしたとき、事後平均・事後分散がどう変わるか

この連鎖率によって、$a \to g \to h \to \mathbf{x}$ という経路で勾配が $\mathbf{x}$ まで届くようになります。実装上はPyTorchの自動微分(autograd)がこの計算を担っています。

セクション8:RQMCとSAA

セクション7で、$\boldsymbol{\varepsilon}^{(n)}$ を固定することで $\mathbf{x}$ への勾配が計算できるようになると説明しました。このセクションでは、その $\boldsymbol{\varepsilon}^{(n)}$ をどのように生成し、どのように扱うかを説明します。

RQMCとは

$\boldsymbol{\varepsilon}^{(n)}$ は標準正規分布 $\mathcal{N}(\mathbf{0}, \mathbf{I})$ からサンプリングされます。単純に一様乱数からサンプリングする通常のMCと比べて、BoTorchではRQMC(Randomized Quasi-Monte Carlo) を使っています。

RQMCはSobolシーケンスにランダムなシフトを加えたものです。セクション5で候補点 $\mathbf{x}$ の初期化にSobolシーケンスを使うと説明しましたが、RQMCはそれとは別の文脈で登場します。

セクション5のSobol セクション8のRQMC
何をサンプリングするか 候補点 $\mathbf{x}$ ランダム変数 $\boldsymbol{\varepsilon}$
目的 初期点を空間に均一に配置する MC積分の分散を低減する

RQMCを使う利点は、通常のMCよりも積分誤差が速く減ることです。通常のMCでは誤差が $O(1/\sqrt{N})$ のオーダーで減るのに対して、RQMCでは理論上それより速く減ります。つまり同じ $N$ でも、RQMCの方が獲得関数の推定精度が高くなります。

BoTorchではSobolQMCNormalSamplerがRQMCを担っており、セクション5・6のサンプルコードで登場していたものがこれにあたります。

SAAとは

SAA(Sample Average Approximation) は、生成した $\boldsymbol{\varepsilon}^{(1)}, \ldots, \boldsymbol{\varepsilon}^{(N)}$ を勾配降下法の全ステップを通じて固定する、という考え方です。

SAAを使わない場合、$\mathbf{x}$ を更新するたびに $\boldsymbol{\varepsilon}^{(n)}$ を引き直します。このとき獲得関数 $\alpha(\mathbf{x})$ は以下のような問題を抱えます。

\alpha(\mathbf{x} + \delta) \approx \frac{1}{N}\sum_{n=1}^{N} a\left(g(h(\mathbf{x}+\delta, \boldsymbol{\varepsilon}^{(n)}_{\text{新}})); \Phi\right), \quad \boldsymbol{\varepsilon}^{(n)}_{\text{新}} \neq \boldsymbol{\varepsilon}^{(n)}_{\text{旧}}

$\mathbf{x}$ を少し動かしただけなのに、サンプル自体が変わってしまうため、$\alpha(\mathbf{x})$ の値がランダムにガタつきます。これでは勾配の方向が信頼できず、勾配降下法が安定しません。

SAAでは $\boldsymbol{\varepsilon}^{(n)}$ を最初に一度だけ生成して固定します。

$$\alpha(\mathbf{x}) \approx \frac{1}{N}\sum_{n=1}^{N} a\left(g(h(\mathbf{x}, \boldsymbol{\varepsilon}^{(n)}_{\text{固定}})); \Phi\right)$$

$\boldsymbol{\varepsilon}^{(n)}$ が固定されると、$\alpha(\mathbf{x})$ は $\mathbf{x}$ だけの決定論的な関数になります。$\mathbf{x}$ を少し動かしたときの $\alpha(\mathbf{x})$ の変化が安定するため、勾配降下法が正しく機能します。

RQMCとSAAの組み合わせ

BoTorchでは、RQMCで低分散な $\boldsymbol{\varepsilon}^{(n)}$ を一度だけ生成し、それをSAAで固定して使いまわします。

  • RQMC:少ない $N$ でも精度の高い $\boldsymbol{\varepsilon}^{(n)}$ を生成する
  • SAA:生成した $\boldsymbol{\varepsilon}^{(n)}$ を固定し、$\alpha(\mathbf{x})$ を $\mathbf{x}$ の滑らかな関数にする

この2つの組み合わせによって、「精度が高く・勾配降下法が安定する」MC獲得関数の最適化が実現されています。BoTorchではSobolQMCNormalSamplerbase_samplesとしてこの固定された $\boldsymbol{\varepsilon}^{(n)}$ が保持されます。

下の図は論文から拝借したRQMC(qMC)やSAA(fixed)について、それらの手法を使用していない場合のEIのグラフになります。
左下がどちらも使っているもので、非常に滑らかで微分可能な形をしています。

RQMC.png

セクション9:勾配降下法で候補点を更新

セクション7・8で、再パラメータ化とSAAによって $\alpha(\mathbf{x})$ が $\mathbf{x}$ の滑らかな関数になり、勾配が計算できるようになることを説明しました。このセクションでは、その勾配を使って実際に $\mathbf{x}$ がどのように更新されるかを説明します。

BoTorchが使うoptimizer:L-BFGS-B

BoTorchは内部でL-BFGS-Bというoptimizerを使っています。L-BFGS-Bは準ニュートン法の一種で、通常の勾配降下法より効率的に最適化を進めます。

通常の勾配降下法は一次微分(勾配)だけを使って更新しますが、準ニュートン法は過去の勾配の履歴からヘッセ行列(二次微分)を近似し、より良い更新方向を計算します。更新式のイメージは以下です。

\mathbf{x}^{(t+1)} = \mathbf{x}^{(t)} + \eta \cdot \mathbf{H}^{-1}_t \nabla_{\mathbf{x}} \alpha(\mathbf{x}^{(t)})

ここで $\mathbf{H}^{-1}_t$ はL-BFGSによるヘッセ行列の近似逆行列です。曲率情報を活用することで、単純な勾配降下法より少ないステップ数で収束します。

名前の末尾のB(Bounded) は、変数に上下限の制約を設けられることを意味します。候補点 $\mathbf{x}$ が探索空間(bounds)を外れないように制約できるため、ベイズ最適化との相性が良いです。

num_restartsとの絡み

勾配降下法は局所最適解に陥るリスクがあります。つまり「その近辺では最大だが、全体として最大ではない点」に収束してしまう可能性があります。BoTorchはこの問題に対して、複数の初期点から並列に最適化を実行する戦略をとっています。

セクション5で説明した内容と合わせると、optimize_acqfの内部では以下の流れになっています。

  1. Sobolシーケンスでraw_samples個の初期候補点を生成
  2. 各候補点で$\alpha(\mathbf{x})$を評価
  3. $\alpha(\mathbf{x})$が高い上位num_restarts個を選択
  4. num_restarts個の初期点からそれぞれL-BFGS-Bで最適化を並列実行
  5. num_restarts個の最適化結果の中で最も$\alpha(\mathbf{x})$が高いxを採用

num_restartsを増やすほど局所最適を避けやすくなりますが、その分計算コストも増えます。おおよそ10〜20程度が実用的なバランスとして使われることが多いです。(AI談)

この流れを踏まえると、optimize_acqfの引数の意味が以下のように整理できます。

candidate, acq_value = optimize_acqf(
    acq_function=qEI,
    bounds=bounds,
    q=2,
    num_restarts=10,   # ④並列に最適化を実行する初期点の数
    raw_samples=512,   # ①Sobolで生成する初期候補点の総数
)

最終的にcandidateとして返される $\mathbf{x}$ が、5で選ばれた最良の候補点です。

まとめ

この記事で説明した概念を以下の表に整理します。

概念 役割 BoTorchでの対応
$f_{\text{true}}(\mathbf{x})$(目的関数) 最大化したい本物の関数 実験・シミュレーション
$f_{\mathcal{D}}$(サロゲートモデル) $f_{\text{true}}$ を確率的に近似 SingleTaskGPなど
$g(\cdot)$(目的関数レイヤー) サロゲート出力を最大化方向に変換 IdentityMCObjectiveなど
$a(\cdot; \Phi)$(utility関数) サンプル軌跡を「良さ」のスカラーに集約 qExpectedImprovement内部
$\alpha(\mathbf{x})$(獲得関数) utilityの期待値・最大化の対象 qExpectedImprovementなど
Sobolシーケンス 候補点 $\mathbf{x}$ の初期配置 raw_samples
num_restarts 勾配降下法の出発点数・局所最適の回避 num_restarts
MC近似のサンプル数 $N$ 獲得関数推定の精度 num_samples
再パラメータ化トリック ランダム性を $\boldsymbol{\varepsilon}$ に分離し $\mathbf{x}$ への勾配を可能にする PyTorch autogradで自動処理
RQMC $\boldsymbol{\varepsilon}$ を低分散でサンプリングする SobolQMCNormalSampler
SAA $\boldsymbol{\varepsilon}$ を固定し $\alpha(\mathbf{x})$ を滑らかにする base_samplesとして保持
L-BFGS-B 候補点 $\mathbf{x}$ の勾配ベース最適化 optimize_acqf内部

計算の全体フローの振り返り

各概念がどのステップで登場するかをセクション4のフローと対応させると以下になります。

  1. Sobolでraw_samples個の候補点$\mathbf{x}$を生成し、num_restarts個に絞り込む(セクション5)
  2. サロゲートモデル$f_{\mathcal{D}}$に$\mathbf{x}$を渡し、事後分布を得る(セクション6)
  3. RQMCで$\boldsymbol{\varepsilon}$をサンプリングし、再パラメータ化で$\boldsymbol{\xi}^{(1)}, \ldots, \boldsymbol{\xi}^{(N)}$を生成する。$\boldsymbol{\varepsilon}$はSAAで固定・使いまわす(セクション7・8)
  4. $g(\boldsymbol{\xi}) \to a(g(\boldsymbol{\xi});\Phi)$ でサンプルごとにスカラー化(セクション2・6)
  5. $N$個の平均 → $\alpha(\mathbf{x})$(セクション3・6)
  6. 連鎖率で$\frac{\partial a}{\partial g} \cdot \frac{\partial g}{\partial h} \cdot \frac{\partial h}{\partial \mathbf{x}}$を計算(セクション7)
  7. L-BFGS-Bで$\mathbf{x}$を更新し②に戻る(セクション9)
  8. num_restarts個の結果から最良の$\mathbf{x}$を採用して返す(セクション9)

補足:BoTorchでの逐次最適化

optimize_acqf()の引数にsequential=Trueを指定すると、合同最適化の代わりに逐次最適化を使用できます。

candidate, acq_value = optimize_acqf(
    acq_function=qEI,
    bounds=bounds,
    q=2,
    num_restarts=10,
    raw_samples=512,
    sequential=True,   # 逐次最適化を有効化
)

逐次最適化では、$q$個の候補点を一度にまとめて最適化するのではなく、1点ずつ順番に決定します。具体的には以下の流れになります。

  1. 1点目の候補点 $\mathbf{x}_1$ をMC獲得関数と勾配降下法(L-BFGS-B)で最適化して決定する
  2. 決定した $\mathbf{x}_1$ を「既に観測する予定の点」として固定し、残りの候補点に対する獲得関数を更新する
  3. 更新された獲得関数で $\mathbf{x}_2$ を最適化して決定する
  4. これを $q$ 点分繰り返す

各ステップのMC獲得関数の最適化には、本記事で説明した再パラメータ化・RQMC・SAAの仕組みがそのまま使われます。

合同最適化と比べて計算コストは増えますが、1点ずつ最適化することで局所最適に陥りにくく、より良い候補点セットが得られる場合があります(参考論文)。

公式ドキュメントでは、適度な$q$ではjoint最適化でも良いが、大きくなってきたらgreedy最適化の方が良さそうみたいなことが書いてあります。

BoTorchのMC獲得関数は、「どんなutility関数でも・どんなqでも同じ枠組みで計算できる汎用性」を実現するために、再パラメータ化・RQMC・SAAという複数の工夫を組み合わせた設計になっています。この記事がBoTorchの内部動作を理解する手がかりになれば幸いです。

0
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?