VAEとシングルセル解析
この記事は、Pythonで実践 生命科学データの機械学習 (https://www.yodosha.co.jp/yodobook/book/9784758122634/) の内容を含んでいます。
私は現在バイオインフォマティクス研究室に所属する学生です。
勉強した事をアウトプットする場として用いていますため、何卒ご理解のほどよろしくお願いいたします。
(>人<;)
変分オートエンコーダ(VAE: Variational Autoencoder)の解説
VAEとは?
VAE(Variational Autoencoder)は、データを圧縮(エンコード)し、その圧縮された情報からデータを再構成(デコード)できる深層学習モデルです。
特に、生成モデルの一種として新しいデータを生成する能力があるのが特徴です。
例えば、VAEは手書き数字のデータを学習して、新しい手書き数字を生成できます。
VAEの必要性
- 次元削減:高次元データを低次元に圧縮し、データの本質を理解する。
- 生成能力:新しいデータを生成できる(例:AIによる顔画像生成)。
- 連続的な潜在空間:潜在空間が連続的なため、データの意味的な変化を理解しやすい。
VAEの仕組み
1. オートエンコーダの基本
- エンコーダ (Encoder):入力データ(例:画像)を潜在変数 (z) に圧縮。
- デコーダ (Decoder):潜在変数 (z) から元のデータを再構築。
オートエンコーダはデータの特徴をつかむのが得意ですが、通常のオートエンコーダは潜在空間がバラバラで、生成がうまくいきません。
2. VAEの改良ポイント
VAEは、次のようにオートエンコーダを改良します:
- 潜在変数を確率分布としてモデル化:$( z )$ を点でなく、平均 $(\mu)$ と分散 $(\sigma^2)$ で表される確率分布で扱う。
- KLダイバージェンス:潜在変数の分布が標準正規分布(平均0、分散1)に近づくように学習。
この仕組みにより、潜在空間が連続的になり、新しいデータをスムーズに生成できます。
数式的なイメージ
潜在変数 (z) の生成
$
[
z \sim \mathcal{N}(\mu, \sigma^2)
]
$
エンコーダはデータ (x) を入力として、$ (\mu) $と $(\sigma^2)$ を出力します。
次に次の式で (z) をサンプルします:
$
[
z = \mu + \sigma \cdot \epsilon,\ \epsilon \sim \mathcal{N}(0, 1)
]
$
学習の目的(損失関数)
VAEは次の損失関数を最小化します:
$\mathcal{L}$ = ${\text{再構成誤差}}$ + $\text{KLダイバージェンス}$
- 再構成誤差:元のデータとデコードされたデータの差を測定。
- KLダイバージェンス:潜在変数の分布を正規分布に近づけるために必要。
VAEの直感的な理解
例えて言うなら、VAEは 「画家がスケッチブックを見て、新しいオリジナルな絵を描けるように訓練するプロセス」 です。
- エンコーダ:スケッチの特徴を理解する。
- デコーダ:その特徴をもとに新しい絵を描く。
- 潜在空間が正規分布に従うため、見たことのない絵でも自然な形で生成可能。
応用例
- 画像生成:手書き数字(MNIST)や顔画像(CelebA)の生成。
- 異常検知:正常データのパターンを学習し、異常データを検出。
- バイオインフォマティクス:ゲノムデータやバイオマーカーの特徴抽出や生成。
変分オートエンコーダ(VAE)のエンコーダーとデコーダーの構造
はじめに
変分オートエンコーダ(VAE)は、データを圧縮し、その圧縮情報から新しいデータを生成できる 深層学習の生成モデル です。
このVAEの中核となるのが エンコーダー(Encoder) と デコーダー(Decoder) です。
🏗️ VAEの基本構造
VAEは エンコーダー・デコーダー・潜在空間 の3つの主要な部分から構成されます。
+-----------------------+
入力データ (x) → | エンコーダー | → 潜在変数 (z)
+-----------------------+
↓
+-----------------------+
潜在変数 (z) → | デコーダー | → 出力データ (x')
+-----------------------+
それぞれの役割は以下のとおりです:
構成要素 | 役割 |
---|---|
エンコーダー | 入力データ (x) を小さな潜在変数 (z) に圧縮する |
潜在空間 | エンコーダーによって得られた情報を分布として表現する |
デコーダー | 潜在変数 (z) から元のデータ (x') を再構築する |
エンコーダーの役割と構造
エンコーダーは、入力データを 圧縮して潜在空間へマッピング する役割を持っています。
通常のオートエンコーダでは、エンコーダーは単にデータを低次元に変換するだけですが、VAEでは 確率的な分布を学習 するため、出力として 平均 $(\mu)$ と 分散 $(\sigma^2)$ を求めます。
エンコーダーの構造
以下のような ニューラルネットワーク で設計されます:
- 入力層:データ(画像、テキストなど)を受け取る
- 隠れ層(数層の全結合層や畳み込み層):特徴を抽出
-
出力層:
- 平均 $(\mu)$ の計算
- 分散 $(\sigma^2)$ の計算
+-------------------+
データ (x) → | 隠れ層 (NN) | → 平均 μ
| |
| 隠れ層 (NN) | → 分散 σ^2
+-------------------+
この $(\mu)$ と $(\sigma^2)$ を用いて、 再パラメータ化トリック によって潜在変数 $(z)$ をサンプリングします。
✨ 再パラメータ化トリック
VAEでは、潜在変数 (z) は 確率分布としてモデル化 されます。
そのため、通常のニューラルネットワークでは学習しにくくなります。
これを解決するのが「再パラメータ化トリック」です。
$$
z = \mu + \sigma \cdot \epsilon, \quad \epsilon \sim \mathcal{N}(0, 1)
$$
$(\epsilon)$ は標準正規分布からサンプリングされ、これにより 勾配を伝播できるようにする ことで、VAEを学習可能にしています。
デコーダーの役割と構造
デコーダーは、潜在変数 $(z)$ を元のデータ $(x')$ に復元する 役割を持っています。
これは、 エンコーダーの逆の処理 を行うと考えられます。
デコーダーの構造
デコーダーは、次のような ニューラルネットワーク で構成されます:
- 入力層:潜在変数 (z) を受け取る
- 隠れ層(数層の全結合層や畳み込み層):特徴を復元
- 出力層:元のデータの形状に戻す
+-------------------+
潜在変数 (z) → | 隠れ層 (NN) | → 出力データ x'
+-------------------+
デコーダーは、「特徴をデコードして元のデータに復元する」ため、エンコーダーとは逆の構造を持つことが一般的です。
- エンコーダー:データを圧縮し、潜在変数 $(z)$ を生成
- 潜在空間:確率分布として表現
- デコーダー:潜在変数 $(z)$ から元のデータを復元
まとめ
- エンコーダー:データを圧縮し、$(\mu)$(平均)と$(\sigma^2)$(分散)を出力
- 再パラメータ化トリック を使って潜在変数 $(z)$ をサンプリング
- デコーダー:潜在変数 $(z)$ を元のデータに復元
VAEの学習、実践
変分オートエンコーダ(VAE)の学習方法・理論・デコーダの尤度関数の解説
- VAEの学習方法(実装とアルゴリズム)
- VAEの理論(確率モデル・損失関数)
- デコーダの尤度関数(確率分布の観点から解説)
1. VAEの学習方法(実装とアルゴリズム)
VAEの学習は、通常のオートエンコーダと異なり、確率分布に基づいた学習 を行います。
そのため、以下の 3つの主要なステップ を実行します。
🏗️ VAEの学習プロセス
- 入力データ ( x ) をエンコーダーで潜在変数 ( z ) に変換
- 潜在変数 ( z ) をデコーダーで元のデータ ( x' ) に復元
- 損失関数(再構成誤差 + KLダイバージェンス)を最適化
🔹 学習の流れ(PyTorchで実装)
import torch
import torch.nn as nn
import torch.optim as optim
# エンコーダ・デコーダの定義(省略)
# VAEの学習ループ
def train_vae(model, dataloader, optimizer, num_epochs=10):
for epoch in range(num_epochs):
for x in dataloader:
optimizer.zero_grad()
# 1. エンコーダーを通して潜在変数を取得
mu, logvar = model.encoder(x)
# 2. 再パラメータ化トリックでzをサンプリング
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
z = mu + eps * std
# 3. デコーダーで再構成
x_recon = model.decoder(z)
# 4. 損失関数を計算(再構成誤差 + KLダイバージェンス)
recon_loss = torch.nn.functional.mse_loss(x_recon, x, reduction='sum')
kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
loss = recon_loss + kl_div
# 5. 逆伝播と最適化
loss.backward()
optimizer.step()
VAEの理論(確率モデル・損失関数)の解説!
1. VAEの確率モデル
通常のオートエンコーダでは、入力データ $( x )$ を 決定論的に 圧縮・復元します。
しかし、VAEでは、潜在変数 $( z )$ を確率分布として扱う ことで、より一般化能力の高い生成モデルを構築します。
VAEの確率モデルの構造
VAEは、確率的生成モデル の形で定義され、以下のような階層的な分布構造を持ちます。
$
[
p(x, z) = p(x|z) p(z)
]
$
- $( p(z) )$ :事前分布(通常は標準正規分布 $( \mathcal{N}(0, I) ))$
- $( p(x|z) )$ :デコーダの条件付き分布$(観測データ ( x ) を生成)$
このモデルでは、潜在変数 $( z )$ を用いて、入力データ $( x )$ を生成することができます。
2. 変分推論とELBO(Evidence Lower Bound)
VAEの学習の目標は、観測データ $( x )$ の確率分布 $( p(x) )$ を最大化すること です。
しかし、実際には ( p(x) ) は積分計算が難しく、直接最大化するのは困難です。
$
[
p(x) = \int p(x|z) p(z) dz
]
$
そこで、近似推論 を用いて、代わりに ELBO(証拠下界) を最大化します。
$\log p(x)$ = $\mathbb{E}$${q(z|x)}[\log p(x|z)]$ - D${KL}$$(q(z|x) || p(z))
]$
ELBOの2つの項
-
再構成誤差$(( \mathbb{E}_{q(z|x)}[\log p(x|z)] ))$
- デコーダが ( z ) から入力 ( x ) をどれだけ正しく復元できたかを表す。
- 例:MSE(平均二乗誤差)やバイナリクロスエントロピーを使う。
-
KLダイバージェンス$(( D_{KL}(q(z|x) || p(z)) ))$
- エンコーダ $( q(z|x) )$ が、事前分布 $( p(z) ) $にどれだけ近いかを測定。
- 事前分布$ ( p(z) ) $を標準正規分布にすることで、潜在空間を整理する。
3. VAEの損失関数
VAEの損失関数は、ELBOを最大化する代わりに、マイナスELBOを最小化 する形で定義されます。
$\mathcal{L} $= $\underbrace{\text{再構成誤差}}$${\text{データを正しく再現}} + \underbrace{\text{KLダイバージェンス}}$${\text{潜在分布の正規化}}$
✔️ 再構成誤差
$
[
\mathbb{E}_{q(z|x)}[\log p(x|z)]
]
$
- ( x ) を ( z ) から生成したときの誤差を測る。
- MSE(平均二乗誤差)やバイナリクロスエントロピー(BCE)が使われる。
✔️ KLダイバージェンス
$
[
D_{KL}(q(z|x) || p(z))
]
$
- $( q(z|x) )(エンコーダの出力)$を、事前分布 $( p(z) )$(通常は標準正規分布)に近づける。
最終的なVAEの損失関数(PyTorch実装)
import torch
import torch.nn.functional as F
def vae_loss(x_recon, x, mu, logvar):
# 再構成誤差(MSE)
recon_loss = F.mse_loss(x_recon, x, reduction='sum')
# KLダイバージェンス
kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
# 合計の損失
loss = recon_loss + kl_div
return loss
VAEのデコーダの尤度関数の解説
1. VAEのデコーダは確率モデル
通常のオートエンコーダ(AE)では、デコーダの出力は 決定論的(固定の値) です。
しかし、VAEでは、確率分布 ( p(x|z) ) をモデル化 し、データを生成します。
デコーダの確率モデル
デコーダの出力 ( p(x|z) ) は、データの性質に応じて異なる確率分布を仮定 します。
データの種類 | 確率分布 |
---|---|
連続値データ(画像、数値データ) | ガウス分布(正規分布) |
バイナリデータ(0 or 1) | ベルヌーイ分布 |
この尤度関数を適切に選ぶことが、VAEの成功にとって重要 です。
2. ガウス分布を用いたデコーダ(連続データ向け)
連続値データ(画像、数値データ)を扱う場合、デコーダの出力はガウス分布(正規分布) を仮定します。
$
[
p(x|z) = \mathcal{N}(x; \mu(z), \sigma^2 I)
]
$
- $( \mu(z) )$ :潜在変数 $( z ) $から得られる平均
- $( \sigma^2 )$ :データのばらつきを表す分散
この分布を用いることで、生成されたデータが滑らかになる というメリットがあります。
PyTorchによるガウス分布デコーダの実装
import torch
import torch.nn as nn
class GaussianDecoder(nn.Module):
def __init__(self, z_dim, h_dim, x_dim):
super(GaussianDecoder, self).__init__()
self.fc = nn.Linear(z_dim, h_dim)
self.fc_mu = nn.Linear(h_dim, x_dim) # 平均
self.fc_logvar = nn.Linear(h_dim, x_dim) # 対数分散
def forward(self, z):
h = torch.relu(self.fc(z))
mu = self.fc_mu(h)
logvar = self.fc_logvar(h)
std = torch.exp(0.5 * logvar) # 標準偏差を計算
return mu, std