本記事は京都大学人工知能研究会KaiRA Advent Calendar 2023 21日目の記事です。
Diffusion Modelに対してブロック単位でNASを行う、「Lightweight Diffusion Models with Distillation-Based Block Neural Architecture Search」という論文について紹介します。
そもそもNAS(Neural Architecture Search)とはニューラルネットのアーキテクチャ(実質はハイパーパラメータ)を適切なものにしようと探索する手法の総称です。結果として人手で決めるよりも精度が良くなったりパラメータ数を少なくできたりします。
そしてこの論文では学習済みのDiffusion ModelのU-Netに対してブロック単位のNASを行う手法を提案しています。ブロック単位で探索を行うため、ネットワーク全体で探索を行うよりも探索幅が小さく効率的な探索を行うことが可能になります。結果、生成画像の品質はそのままにパラメータ数を半分まで落とすことに成功しています。
手法
本論文で提案している手法は大きく分けて3ステップに分かれます。
1ステップ目では探索するアーキテクチャを内包するスーパーネットと呼ばれるネットワークを蒸留によって学習します。
2ステップ目ではスーパーネットを評価し、損失値がまあまあ小さくなるアーキテクチャのうち最もパラメータ数の少ないものを選びます。
3ステップ目ではそのアーキテクチャを採用した上で損失が十分下がるまで再学習を行います。
それぞれもう少し詳しく説明していきます。
スーパーネットの学習
上述の通り、U-Netのスーパーネットをブロック単位で蒸留によって学習していきます。すなわち、生徒モデルの各ブロックへの入力として教師モデルの各ブロックへの入力をそのまま持ってきて、生徒モデルの各ブロックの出力が教師モデルの各ブロックの出力と同じになるように学習します。具体的には各ブロックの出力の二乗和誤差を最小化するよう最適化されます。論文の数式で表すと以下のようになります。
\mathcal{L}_{\mathrm{train}} = \| \mathrm{Block}_{\alpha}\left(\mathbf{X_i}\right) - \mathbf{Y_i} \|^2_2
ただし$\mathrm{Block}_{\alpha}$はサンプリングされたアーキテクチャ、$\mathbf{X_i}, \mathbf{Y_i}$はそれぞれ教師モデルの$i$番目のブロックの入出力です。
そしてスーパーネットからのアーキテクチャのサンプリングにはrandom single-path strategy(論文)を用いていると書いています。例えばチャネル数を探索したい場合には、あらかじめ最大のチャネル数$c_{max}$を決めておいてチャネル数$c_{max}$のパラメータを用意しておきます。そしてforward計算をする時にチャネル数$c$をランダムに選択し、用意しておいたパラメータの最初の$c$チャネル分のみを用いて計算します。この方法により各アーキテクチャが効率的に学習されます。
スーパーネットの評価
スーパーネットに含まれる各アーキテクチャのうち、最も「コスパの良い」ものを選びます。「コスパの良い」アーキテクチャとは、損失値が教師モデルのアーキテクチャを採用した場合とほぼ同等かそれ以下のアーキテクチャのうち、最もパラメータ数の少ないものとしています。論文の数式で書けば、
\mathrm{argmin}_{\alpha_i \in \mathcal{A}_i} \, \mathrm{Cost}(\alpha_i)
\mathcal{L}_{\mathrm{val}}(w_i, \alpha_i; \mathbf{X_i}, \mathbf{Y_i}) \leq r \mathcal{L}_{\mathrm{val}}(w_i, \alpha_{\mathrm{base}}; \mathbf{X_i}, \mathbf{Y_i})
となります。
ただし、
- $\mathcal{A}_i$…探索するアーキテクチャの集合
- $\mathrm{Cost}(\alpha_i)$…アーキテクチャ$\alpha_i$のコスト(パラメータ数)
- $\mathcal{L}_{\mathrm{val}}(w_i, \alpha_i; \mathbf{X_i}, \mathbf{Y_i})$…ブロック$i$でのアーキテクチャ$\alpha_i$における二乗和誤差
- $\alpha_{\mathrm{base}}$…教師モデルと同じアーキテクチャ
です。また、$r$は教師モデルの損失よりもほんの少し大きい損失を許容するためのパラメータで、$1.02$のような値を用います。
このステップの計算は、規模の大きな評価データを用いる必要はなく、いくらかのミニバッチで計算すれば良いそうです。そのため、上記のスーパーネットの学習と合わせてもそこまで計算コストがかからないとのことです。(具体的にはA100一つで半日程度)
再学習
アーキテクチャを決定した後、十分損失が下がるまで再学習を行います。具体的には以下の式で学習が行われます。
\mathcal{L}_{\mathrm{retrain}} = \gamma \cdot (1-\beta)\mathcal{L}_{\mathrm{dis}} + \beta\mathcal{L}_{\mathrm{ori}}
ただし、
- $\mathcal{L}_{\mathrm{dis}}$…上記同様、蒸留の損失(二乗和誤差)
- $\mathcal{L}_{\mathrm{ori}}$…Diffusion Modelの損失(ノイズ同士の二乗和誤差)
- $\gamma$…重み係数
- $\beta$…バランシングのための係数
です。$\beta$は$0$から$1$まで線形、もしくはステップ関数的に変化させます。
実験
設定
Conv層のカーネルサイズを探索しています。探索する範囲は$1, 3, 5$としています。(元のカーネルサイズは$3$)
モデルはピクセル空間におけるDiffusion Modelで検証しており、データセットにはCIFAR-10、CelebA、LSUN-churchを用いています。
結果
CIFAR-10における各設定でのFIDの比較です。論文の表1を引用します。
S1、S2はそれぞれ評価時に$r=1.02, 1.00$とした場合を表しています。また、linearとstepは再学習時に$\beta$を線形に変化させた場合とステップ的に変化させた場合を表しています。S2(linear)は教師モデルよりもFIDが低くなっており、dpm-20を除き3種類の中で最小になっています。
また、Ablation Studyとして、ランダムにアーキテクチャを決定した場合(Random)と比較してもdpm-20を除きS2の方がFIDが良く、探索に効果があることが確かめられます。
さらに、再学習時に蒸留損失を使わないNo-Dis-S2、$\beta$固定のFixed-Loss-S1の2種類も検証しています。どちらもFIDは悪化しているため、再学習時にも蒸留損失を用いること、$\beta$を変化させることにも効果があることが確かめられます。
その他、論文にはCelebA、LSUN-churchを用いた場合にはstepの方がFIDが良くなることが示されています。
最後に、Latent Diffusion Modelに対しても検証しています。論文の表4を引用します。
どちらもFIDは教師モデルより僅かに悪化していますが、パラメータを約半分まで減らすことができています。
最近はStable Diffusionでの1-stepでの画像生成が可能になってきたため、今後は「1-stepにどれだけ時間がかかるか」という点はますます重要になってくると思います。