0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

[機械学習/深層学習] VAEを用いた顔画像の欠損補完を試してみた(CelebA)

Posted at

 変分オートエンコーダ(VAE)は、画像生成モデルとして知られていますが、欠損補完(inpainting)やノイズ除去などの「画像補正タスク」にも利用できます。
本記事では、CelebA(顔画像データセット)を用いて、顔画像の一部を意図的に欠損させて、VAEでその欠損部分を補完してみました。
 理論の厳密さよりも 「実際に動かして挙動を理解する」 ことを目的としています。

最終的にこのようにoriginal(元の画像)、masked(欠損画像)、recon(修復、補完結果)を表示させるところまで試します。まだ結果がぼやけていますがハマった点を最後に記載します。
スクリーンショット 2026-01-15 0.35.04.png

前提

  • 実行環境はGoogle Colab。ランタイムはPython3(T4 GPU)を使用
     ※ 参照:機械学習・深層学習を勉強する際の検証用環境について
  • 数学的知識や用語の説明について、参考文献やリンクを最下部に掲載 (本記事内で詳細には解説しませんが、流れや実施内容がわかるようにしたいと思います)

やったことの概要

今回行ったことをまとめると以下の通り。

  • CelebAデータセットを用意
  • 画像を 64×64 に前処理
  • 画像中央をマスクして欠損画像を作成
  • Conv-VAE を用いて欠損補完を学習
  • 再構成結果を可視化
  • 発生した典型的な問題を修正

全体の流れ

original image
   ↓
preprocess (crop, resize)
   ↓
mask(欠損画像を作る)
   ↓
Conv-VAE に入力
   ↓
reconstruction(欠損補完)

実装

1. データセットの準備(CelebA)

CelebAは約20万枚の顔画像からなるフリーのデータセット。

今回は Colab で扱いやすいように、
Hugging Face Datasets のミラーを利用した。

!pip -q install -U datasets pillow tqdm

import os, math, random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm
from PIL import Image
# ---------------------------
# 1 Repro / Device
# ---------------------------
seed = 1234
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device:", device)

# ---------------------------
# 2 Load cifar10 (free)
# ---------------------------
from datasets import load_dataset

train_hf = load_dataset("nielsr/CelebA-faces", split="train[:30000]")
valid_hf = load_dataset("nielsr/CelebA-faces", split="train[30000:33000]")

2. 前処理と欠損画像の作成

顔が中央に来るように中央クロップし、64×64 にリサイズする。また画像中央を欠損させるマスクを作成。

# ---------------------------
# 3 Dataset wrapper
# ---------------------------
class CelebAInpaintDataset(Dataset):
    def __init__(self, hf_ds, image_size=64):
        self.ds = hf_ds
        self.image_size = image_size

    def __len__(self):
        return len(self.ds)

    def _preprocess(self, pil_img: Image.Image) -> torch.Tensor:
        # CelebAは顔が中心に寄ってるので、まず中央Crop→64に縮小が無難
        # 元は178x218。中央を正方形にしてから64へ。
        w, h = pil_img.size
        side = min(w, h)
        left = (w - side) // 2
        top  = (h - side) // 2
        pil_img = pil_img.crop((left, top, left + side, top + side))
        pil_img = pil_img.resize((self.image_size, self.image_size), Image.BILINEAR)

        arr = np.array(pil_img).astype(np.float32) / 255.0  # (H,W,3) [0,1]
        arr = np.transpose(arr, (2, 0, 1))                  # (3,H,W)
        return torch.from_numpy(arr)

    def _center_mask(self, x: torch.Tensor) -> torch.Tensor:
        # x: (3, H, W)
        _, H, W = x.shape
        mask = torch.ones_like(x)
        h0, h1 = H // 3, 2 * H // 3
        w0, w1 = W // 3, 2 * W // 3
        mask[:, h0:h1, w0:w1] = 0.0
        return mask

    def __getitem__(self, idx):
        img = self.ds[idx]["image"]  # PIL
        x = self._preprocess(img)    # (3,64,64)

        mask = self._center_mask(x)
        x_masked = x * mask

        return x_masked, x, mask     # 入力(欠損), 正解(元), mask

# ---------------------------
# 4 DataLoader
# ---------------------------
image_size = 64
batch_size = 128

train_ds = CelebAInpaintDataset(train_hf, image_size=image_size)
valid_ds = CelebAInpaintDataset(valid_hf, image_size=image_size)

train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True)
valid_loader = DataLoader(valid_ds, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True)

print("train:", len(train_ds), "valid:", len(valid_ds))
  • 入力:欠損画像(masked)
  • 教師:元の画像(original)
    という形で学習させる。

3. モデル構成(Conv-VAE)

MLPでは表現力が不足するため、畳み込みベースのVAEを使用。

Encoder

  • Conv → ReLU を重ねて特徴抽出
  • 潜在変数の平均 μ と分散 logσ² を出力

Decoder

  • 潜在変数 z から ConvTranspose で画像を復元
  • 出力は logits(sigmoidしない)

損失関数(VAE)

# ---------------------------
# 5 Conv-VAE
#   - decoderはlogits出力(sigmoidしない)
#   - lossは BCEWithLogits + KL
# ---------------------------
class ConvVAE(nn.Module):
    def __init__(self, z_dim=64, img_ch=3):
        super().__init__()
        self.z_dim = z_dim

        # Encoder: (3,64,64) -> (256,4,4)
        self.enc = nn.Sequential(
            nn.Conv2d(img_ch, 32, 4, 2, 1),  # 32x32
            nn.ReLU(True),
            nn.Conv2d(32, 64, 4, 2, 1),      # 16x16
            nn.ReLU(True),
            nn.Conv2d(64, 128, 4, 2, 1),     # 8x8
            nn.ReLU(True),
            nn.Conv2d(128, 256, 4, 2, 1),    # 4x4
            nn.ReLU(True),
        )
        self.enc_fc = nn.Linear(256 * 4 * 4, 512)
        self.mu = nn.Linear(512, z_dim)
        self.logvar = nn.Linear(512, z_dim)

        # Decoder: z -> (256,4,4) -> (3,64,64) logits
        self.dec_fc = nn.Sequential(
            nn.Linear(z_dim, 512),
            nn.ReLU(True),
            nn.Linear(512, 256 * 4 * 4),
            nn.ReLU(True),
        )
        self.dec = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 4, 2, 1),  # 8x8
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1),   # 16x16
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 32, 4, 2, 1),    # 32x32
            nn.ReLU(True),
            nn.ConvTranspose2d(32, img_ch, 4, 2, 1) # 64x64 logits
        )

    def encode(self, x):
        h = self.enc(x)
        h = h.view(h.size(0), -1)
        h = F.relu(self.enc_fc(h))
        mu = self.mu(h)
        logvar = self.logvar(h)
        return mu, logvar

    def reparam(self, mu, logvar):
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return mu + std * eps
        return mu

    def decode_logits(self, z):
        h = self.dec_fc(z)
        h = h.view(h.size(0), 256, 4, 4)
        logits = self.dec(h)
        return logits

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparam(mu, logvar)
        logits = self.decode_logits(z)
        return logits, mu, logvar

    def loss(self, x_masked, x_target, mask, beta=1.0):
        logits, mu, logvar = self.forward(x_masked)

        bce = F.binary_cross_entropy_with_logits(logits, x_target, reduction="none")  # (B,3,H,W)
        miss = (1.0 - mask)  # 欠損部分=1

        # 欠損部分のrecon
        denom_m = torch.sum(miss, dim=(1,2,3)).clamp(min=1.0)
        recon_miss = torch.sum(bce * miss, dim=(1,2,3)) / denom_m

        # 全体のrecon(弱めに入れて“顔らしさ”を保つ)
        recon_all = torch.mean(torch.sum(bce, dim=(1,2,3)))

        lam = 0.9  # 欠損重視。0.7〜0.95で調整
        recon = torch.mean(lam * recon_miss + (1.0 - lam) * recon_all)

        # KL
        kl = -0.5 * torch.mean(torch.sum(1 + logvar - mu**2 - torch.exp(logvar), dim=1))

        loss = recon + beta * kl
        return loss, recon, kl

4. 学習

欠損補完タスクとして Conv-VAE を学習する。

学習時は、

  • 入力:欠損画像(masked image)
  • 教師信号:元の画像(original image)
  • 目的:欠損部分を自然に復元できる潜在表現を学習する
    という設定で最適化を行う。

損失関数
損失は以下の2項から構成される。

  • 再構成誤差(Binary Cross Entropy)
    • 欠損部分を中心に評価
    • ただし全体の整合性を保つため、全画素の誤差も弱く混ぜる
  • KLダイバージェンス
    • 潜在変数が標準正規分布に近づくよう正則化

検証
各エポック終了時に validation データで

  • loss
  • 再構成誤差
  • KL項
    を計算し、学習の進行状況を確認した。
# ---------------------------
# 6 Train
# ---------------------------
z_dim = 128
model = ConvVAE(z_dim=z_dim).to(device)
opt = torch.optim.Adam(model.parameters(), lr=2e-4)

epochs = 10
beta = 0.05

def eval_epoch():
    model.eval()
    losses, recons, kls = [], [], []
    with torch.no_grad():
        for x_masked, x, mask in valid_loader:
            x_masked = x_masked.to(device, non_blocking=True)
            x = x.to(device, non_blocking=True)
            loss, recon, kl = model.loss(x_masked, x, mask.to(device), beta=beta)
            losses.append(loss.item())
            recons.append(recon.item())
            kls.append(kl.item())
    return float(np.mean(losses)), float(np.mean(recons)), float(np.mean(kls))

for ep in range(1, epochs+1):
    model.train()
    pbar = tqdm(train_loader, desc=f"epoch {ep}/{epochs}")
    losses, recons, kls = [], [], []
    for x_masked, x, mask in pbar:
        x_masked = x_masked.to(device, non_blocking=True)
        x = x.to(device, non_blocking=True)

        opt.zero_grad(set_to_none=True)
        loss, recon, kl = model.loss(x_masked, x, mask.to(device), beta=beta)
        loss.backward()
        opt.step()

        losses.append(loss.item())
        recons.append(recon.item())
        kls.append(kl.item())
        pbar.set_postfix(loss=np.mean(losses), recon=np.mean(recons), kl=np.mean(kls))

    v_loss, v_recon, v_kl = eval_epoch()
    print(f"[VALID] loss={v_loss:.3f} recon={v_recon:.3f} kl={v_kl:.3f}")

5. 修復結果の可視化

学習後、欠損補完の結果を視覚的に確認するため、

  • original(元画像)
  • masked(欠損画像)
  • recon(補完結果)
    を横並びで表示した。
# ---------------------------
# 7 Visualize inpainting results (3列: original / masked / recon)
# ---------------------------
import matplotlib.pyplot as plt

def to_img(t):
    # (3,H,W) tensor [0,1] -> (H,W,3) np
    t = t.detach().cpu().clamp(0,1)
    return np.transpose(t.numpy(), (1,2,0))

model.eval()
x_masked, x, mask = next(iter(valid_loader))
x_masked = x_masked.to(device)
x = x.to(device)

with torch.no_grad():
    logits, _, _ = model(x_masked)
    recon = torch.sigmoid(logits)
    recon = recon * (1.0 - mask.to(device)) + x_masked * mask.to(device)

n_show = 8
plt.figure(figsize=(12, 5))
for i in range(n_show):
    # original
    plt.subplot(3, n_show, 1 + i)
    plt.imshow(to_img(x[i]))
    plt.axis("off")
    if i == 0: plt.title("original")

    # masked
    plt.subplot(3, n_show, 1 + n_show + i)
    plt.imshow(to_img(x_masked[i]))
    plt.axis("off")
    if i == 0: plt.title("masked")

    # recon
    plt.subplot(3, n_show, 1 + 2*n_show + i)
    plt.imshow(to_img(recon[i]))
    plt.axis("off")
    if i == 0: plt.title("recon")
plt.tight_layout()
plt.show()

※ 表示結果
スクリーンショット 2026-01-15 0.35.04.png

最後に

ハマったポイントと対処です。

ハマり①:recon がすべて「平均顔」になる
原因

  • 欠損領域が大きい
  • VAEは尤度最大化のため、安全な平均解を出しやすい

対処

  • β を小さくする(例:beta = 0.01)
  • 潜在次元を大きくする(例:z_dim = 128)
    👉 VAEとしては正常な挙動

ハマり②:目がたくさんある変な顔になる
原因

  • 再構成誤差を「欠損部分のみ」にすると自由度が高すぎる
  • テクスチャだけで loss を下げる局所解に陥る
    対処:欠損部+全体を混ぜる
recon = λ * recon_missing + (1-λ) * recon_all
lam = 0.9

欠損部分を重視しつつ、全体の「顔らしさ」も維持した。

可視化が不自然
原因

  • マスク外まで再構成結果を表示していた

対処(表示用)

recon = recon * (1-mask) + x_masked * mask

結果と所感

  • 欠損補完として自然な顔が生成されるようになった
  • VAEは「曖昧な部分を平均化する」性質がはっきり観察できた
  • 損失設計が挙動に大きく影響することを実感した

参考文献、リンク

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?