0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

【PyTorch】デノイジングCNN:DCT2Net を解説・実装

Last updated at Posted at 2024-08-01

はじめまして,naoki2902 と申します.私は現在大学院において,コンピュテーショナルイメージング(Computational Imaging: CI)について研究をしています.今回は,CI においても重要な課題である「デノイジング」について興味深い論文を見つけたので,解説・実装していきたいと思います.
なお,今回が初めての Qiita 記事投稿となります.

紹介論文

概要と新規性

  • DCT denoiser を「浅い CNN」とみなした DCT2Net を提案した.

  • DCT2Net の結果と従来の DCT denoiser の結果を組み合わせることで,DCT denoiser の性能を上回るデノイジングを実現した.

  • 従来の CNN を用いた手法とは異なり,CNN 内部の構造が「ブラックボックス」ではなくなり,解釈が可能になった.

GitHub

1. はじめに

1-1. デノイジング

CI や画像処理の領域では,画像 $x\in\mathbb{R}^{N\times N}$ を取得したいときにノイズ $\varepsilon\in\mathbb{R}^{N\times N}$ が加わった画像 $y=x+\varepsilon\in\mathbb{R}^{N\times N}$ しか得ることのできない場面が多く存在します.このようなとき,画像 $y$ から画像 $x$ を復元するために用いられる手法が「デノイジング」と呼ばれています.

1-2. 先行研究

1-2-1. DCT Denoiser[1]

離散コサイン変換(Discrete Cosine Transform: DCT)を用いた非学習デノイジング手法です.本手法は単純かつ高速であるという利点がある一方で,画像 $x$ の微細な構造が失われてしまうという欠点も挙げられます.手法の詳細は2-1において説明します.

1-2-2. CNN

畳み込みニューラルネットワーク(Convolutional Neural Networks: CNN)の利用はコンピュータビジョンの領域で近年急速に拡大しています.デノイジングもその例外でなく,さまざまな手法が提案されているようです.

1-3. 実験環境

  • チップ:Apple M1 Pro
  • メモリ:16 GB
  • macOS:Sonoma 14.5
  • Python:3.10.10
  • PyTorch:1.11.0

2. DCT Denoiser

まずは,本論文の前提である DCT denoiser について解説・実装します.

2-1. 手法

DCT denoiser において,画像 $y\in\mathbb{R}^{N\times N}$ のピクセル $k$ におけるデノイズ結果 $F(y)_k\in\mathbb{R}^{N\times N}$ は次の式によって表されます[2]:

F(y)_k=\frac{1}{W_k}\sum_{i=-q}^{q}\sum_{j=-q}^{q}w_{i,j,k}[P\varphi_{\lambda}(P^{-1}y_{k,p}^{i,j})]_{s(i, j)}

ただし,$p\in\mathbb{N}$ はパッチサイズ,$i, j\in\mathbb{Z}$ は $p\times p$ のパッチ内における座標(パッチ中心で $i=j=0$),$q=\lfloor p/2 \rfloor\in\mathbb{N}$ です.$y_{k,p}^{i,j}\in\mathbb{R}^{p\times p}$ は,$y$ の $k$ 番目のピクセルについて座標 $(i,j)$ を中心に $p\times p$ の領域をとってきたものです.$P\in\mathbb{R}^{p^2\times p^2}$ は DCT に関する基底行列で,

P_{i=xp+y+1,j=up+v+1}=\frac{2}{p}\alpha(u)\alpha(v)\cos\left[\frac{(2x+1)u\pi}{2p}\right]\cos\left[\frac{(2y+1)v\pi}{2p}\right]
\alpha(u)=\begin{cases}1/\sqrt{2}\,\,(u=0)\\1\,\,(\text{otherwise})\end{cases}

と書くことができます.また,$\varphi_\lambda(x)$ は Hard shrinkage 関数であり,

\varphi_\lambda(x)=\begin{cases}0\,\,(x\in [-\lambda,\lambda])\\ x\,\,(\text{otherwise})\end{cases}

と表されます.$s(i,j)=(q-i)p+q-j+1$ であり,$w_{i,j,k}$ と乗算したい座標を表しています.さらに,重み $w_{i,j,k}, W_k$ はそれぞれ

w_{i,j,k}=1/(1+\|\varphi_\lambda(P^{-1}y_{k,p}^{i,j})\|_0)\in\mathbb{R}
W_k=\sum_{i=-q}^{q}\sum_{j=-q}^{q}w_{i,j,k}\in\mathbb{R}

です.なお,この $F(\cdot)$ は従来の DCT denoiser[1] とは異なる式で表されていることに注意してください.

以上のように DCT denoiser の式を書いてみたのですが,さすがにこの式だけだと解釈が非常に難しいので,それぞれの項で何が行われているかのイメージをお伝えします.画像 $y$ をパッチサイズにしたがって展開します.その後,行列 $P^{-1}$ によって DCT 変換をします.変換後の空間において,Hard shrinkage 関数 $\varphi_\lambda$ によってノイズに関する周波数領域を切り落とします.そして,行列 $P$ によって実空間に戻し,重み $w$ によって各パッチの平均を取ります.イメージできたでしょうか?

2-2. 実装

2-1 で定式化した DCT denoiser を Python および PyTorch を使って実装してみます.
DCT denoiser はクラスにより表現します.これにより,各パラメータを体系的に保存しながらシミュレーションを行うことができます.実際に記述したコードを以下に記載します.

dct.py
import numpy as np
import torch
import torch.nn as nn

def alpha(u):
    if u == 0:
        return 1 / np.sqrt(2)
    else:
        return 1

class DCTDenoiser(nn.Module):

    def __init__(self, patch_size, threshold):

        super(DCTDenoiser, self).__init__()

        self.patch_size = patch_size
        self.threshold = threshold

        P = np.zeros((patch_size**2, patch_size**2))
        for x in range(patch_size):
            for y in range(patch_size):
                for u in range(patch_size):
                    for v in range(patch_size):
                        P[x*patch_size+y, u*patch_size+v] += 2 / patch_size * alpha(u) * alpha(v) * np.cos((2*x+1)*u*np.pi/(2*patch_size)) * np.cos((2*y+1)*v*np.pi/(2*patch_size))
        P = P.astype(np.float32)
        self.P = nn.Parameter(torch.from_numpy(P))
        self.Pinv = nn.Parameter(torch.from_numpy(P.T))
        self.P.requires_grad_(False)
        self.Pinv.requires_grad_(False)

        self.unfold = nn.Unfold(kernel_size=patch_size)
        self.shrink = nn.Hardshrink(threshold)

    def forward(self, y):

        _, _, height, width = y.shape
        self.fold = nn.Fold(output_size=(height, width), kernel_size=self.patch_size)

        y = self.unfold(y)
        y = self.Pinv @ y
        y = self.shrink(y)
        w = 1 / (1 + torch.count_nonzero(y, dim=1))
        y = self.P @ y
        y *= w
        y = self.fold(y)
        
        W = self.fold(w.unsqueeze(1).repeat(1, self.patch_size**2, 1))
        output = y / W

        return output

コードの解説をしていきます.self.unfold で画像 y を展開します.次に,self.Pinv @ y で DCT 変換をします.続いて,self.shrink(y) で Hard shrinkage 関数により帯域制限を課した後に,self.P @ y で実空間に戻してあげます.最後に,重み w をもとに折りたたむ(平均をとる:self.fold)ことで,デノイジング結果である output を得ています.2-1で説明した通りの順番ですね.

2-3. 結果

2-2で作成した DCTDenoiser を用いて実際にデノイジングを行ってみます.なお,今回は簡単のために図1の飛行機の画像に絞って結果を示します.
各パラメータの値は次の通りです.

  • 画像サイズ:$224 \times 224$
  • ガウシアンノイズの標準偏差 $\sigma$:$25$
  • Hard shrinkage 関数の閾値 $\lambda$:$3\times \sigma$
  • パッチサイズ $p$:$13$

infer_dct.py を実行した結果を下に示します.

図3より,DCT denoiser によってノイズが除去されていることが定性的にも定量的にもわかります.一方で,飛行機まわりの微細な構造が消えてしまっていることもわかります(飛行機の横のロゴやプロペラなど).この点が DCT denoiser の欠点として広く知られています.

3. DCT2Net

続いて,本論文の主題である DCT2Net について解説・実装をします.

3-1. 手法

DCT2Net では,図2に示した DCT denoiser を「浅い CNN」とみなします(図4).具体的には,DCT 変換を「畳み込み層」,Hard shrinkage 関数を「活性化関数」とみなします.そして,DCT 変換行列 $P$ を初期値として重みを学習します.

ここで,Hard shrink 関数 $\varphi_\lambda(x)$ は図5に示す通り「微分不可能な関数」であり,誤差逆伝播に支障が出てしまいます.したがって,微分可能な関数に近似する必要があります.新たな関数を $\zeta_{m,\lambda}(x)$ とすると,

\zeta_{m,\lambda}(x)=\frac{x^{2m}}{x^{2m}+\lambda^{2m}}\cdot x

と表せます.実際,図5をみると,良い精度で Hard shrinkage 関数を近似できていることがわかります.

3-2. 実装

DCT2Net について Python および PyTorch を使って実装してみます.実際に記述したコードを以下に記載します(一部省略).

dct2net.py
import torch.nn.functional as F

class DCT2Net(nn.Module):

    def __init__(self, patch_size, threshold):

        super(DCT2Net, self).__init__()

        self.patch_size = patch_size
        self.threshold = threshold

        P = np.zeros((patch_size**2, patch_size**2))
        for x in range(patch_size):
            for y in range(patch_size):
                for u in range(patch_size):
                    for v in range(patch_size):
                        P[x*patch_size+y, u*patch_size+v] += 2 / patch_size * alpha(u) * alpha(v) * np.cos((2*x+1)*u*np.pi/(2*patch_size)) * np.cos((2*y+1)*v*np.pi/(2*patch_size))
        P = P.astype(np.float32)

        filter = torch.zeros(patch_size**2, 1, patch_size, patch_size)
        for k in range(patch_size**2):
            filter[k, 0, k // patch_size, k % patch_size] = 1.0

        self.conv1 = nn.Conv2d(in_channels=1, 
                               out_channels=patch_size**2, 
                               kernel_size=patch_size, 
                               bias=False)
        self.conv1.weight = nn.Parameter(torch.from_numpy(P.T).view(patch_size**2, 1, patch_size, patch_size)) # based on P.T

        self.conv3 = nn.ConvTranspose2d(in_channels=patch_size**2, 
                                        out_channels=1, 
                                        kernel_size=patch_size, 
                                        bias=False)
        self.conv3.weight = nn.Parameter(filter) # filter
        self.conv3.requires_grad_(False)

        self.conv4 = nn.ConvTranspose2d(in_channels=1, 
                               out_channels=1, 
                               kernel_size=patch_size, 
                               bias=False)
        self.conv4.weight = nn.Parameter(torch.ones(1, 1, patch_size, patch_size)) # average
        self.conv4.requires_grad_(False)

        self.shrink = DifferentiableThreshold.apply

    def forward(self, y):

        y_conv1 = self.conv1(y)
        shrink = self.shrink(y_conv1 / (3 * self.threshold))
        y_shrink = y_conv1 * shrink
        w = 1 / (1 + torch.sum(shrink, dim=1))
        w = w.unsqueeze(1)

        y_conv2 = F.conv2d(y_shrink, torch.inverse(self.conv1.weight.view(self.patch_size**2, self.patch_size**2)).view(self.patch_size**2, self.patch_size**2, 1, 1))
        y_conv3 = self.conv3(y_conv2 * w)

        w_conv4 = self.conv4(w)
        out = y_conv3 / w_conv4

        return out

コードの説明をしていきます.self.conv1 により,入力画像 y を展開しつつ DCT 変換しています.ただし,この畳み込み層の重みは徐々に更新されていくので,DCT 変換となっているのは最初のイテレーションのみとなることに注意してください.次に,活性化関数として shrink を乗算します.ここで,活性化関数 $\zeta_{\lambda,m}(x)$ は PyTorch にモジュール化されていないので,自分で関数を記述する必要があります.以下のコードによって $\zeta_{\lambda,m}(x)$ を定義しました.

class DifferentiableThreshold(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x):

        ctx.save_for_backward(x)
        m = 32

        # 近似する & nanを防ぐ
        th = 1.5
        mask = torch.abs(x) >= th

        y = torch.ones_like(x)
        z = torch.pow(x, 2*m)
        y[~mask] = z[~mask] / (z[~mask] + 1)

        return y
    
    @staticmethod
    def backward(ctx, dL_dy):

        x, = ctx.saved_tensors
        m = 32

        # 近似する & nanを防ぐ
        th = 1.5
        mask = torch.abs(x) >= th

        dy_dx = torch.zeros_like(x)
        z = torch.pow(x, 2*m-1)
        dy_dx[~mask] = 2 * m * z[~mask] / (x[~mask] * z[~mask] + 1) ** 2
        dL_dx = dL_dy * dy_dx

        return dL_dx

ここで,図5から近似精度が高いと思われる部分は $\varphi_{\lambda}(x)$ に近似しています.これをしないと x が大きいときに nan が出力されてしまいます.
さて,DCT2Net に戻ります.F.conv2d により,行列を実空間に戻してあげます.このとき,この畳み込み層の重みは self.conv1 のものから計算されることに注意する必要があります.最後に,重み w とともに self.conv3 によって折りたたみます.重みの平均 Wself.conv4 から求められます.

3-3. 結果

3-3-1. 学習

まずは,3-2で作成した DCT2Net の重みの学習を行います.使用したデータセットは,Berkley segmentation dataset です.このデータセットには風景や動物のカラー写真が約400枚保存されています.今回は,白黒画像に変換した画像をそのまま使って学習を行います.反転や回転などの Data augumentaion は行っていません.
主なパラメータの値は次の通りです.

  • 画像サイズ:$224\times 224$
  • ガウシアンノイズの標準偏差 $\sigma$:$25$
  • Shrinkage 関数の閾値 $\lambda$:$3\times \sigma$
  • パッチサイズ $p$:$13$
  • バッチサイズ:$32$
  • エポック数:$15$

また,誤差関数は nn.MSELoss を使用しました.
学習は train.py により行いました.学習時間は CPU を使用して約45分ほどでした.GPU を使用すればかなり早く学習できそうです.エポックごとのロスカーブは図6のようになったので,学習はうまくいっていそうです.

3-3-2. 学習結果

3-3-2-1. デノイジング結果

学習前後それぞれについてデノイジング結果を求めてみました.infer_dct2net.py を実行した結果を図7に示します.この図より,学習によってPSNR が上昇したことがわかり,学習による効果を確認することができました.一方で,学習後であっても DCT denoiser を用いた結果には劣っている結果でした.特に,雲の領域にノイズが多く乗ってしまっていて,これが PSNR を大きく下げる要因になっていそうです.また,活性化関数を微分可能なものに変えたことも PNSR の悪化に影響している可能性があります.

3-3-2-2. 学習後重みパラメータ

DCT 変換に対応する畳み込み層の重みパラメータは学習によってどのように変化したのでしょうか?図8に結果を示します.この図を見ると,学習前後で見た目上の変化はほぼないように見えます.しかし,図7のようにデノイジング結果には大きな差が生じているので,微小な重みの変化が重要であることがわかります.

4. Hybrid Denoising

DCT denoiser による結果と DCT2Net による結果はそれぞれ「細かい構造が消えてしまう」「おおまかな領域にノイズが生じてしまう」という問題がありました.そこで,2つのモデルの欠点を相互に補うような新たなモデルがこの論文では提案されています.これをここでは Fusion denoising と呼ぶことにします.

4-1. 手法

細かい構造は「エッジ」と捉えることができます.そこで,DCT denoiser により得られた画像からエッジ部分を推定して,エッジマスクを作成します.エッジでない部分は DCT denoiser の方が得意なので,その部分は DCT denoiser の結果を使用し,エッジ部分は DCT2Net の方が得意なので,その部分は DCT2Net の結果を使用します.これら2つの結果を加算することで,より良い結果を得ることができるとしています.

4-2. 結果

それでは Fusion denoiser を用いたときの結果について見てみましょう.infer_fusion.py を実行した結果を図10に示します.DCT2Net によるノイズが消えていることがはっきりわかります.一方で,エッジ部分が綺麗になっているかどうかは目視では確認することができませんでした.また,PSNR は DCT2Net によるものには優っていましたが,DCT denoiser によるものには劣っている結果となりました.この件について,次節で考察してみます.

5. 考察

2-4節の実験結果を改めてまとめてみます.

Model PSNR
Noise 20.19 dB
DCT denoiser 32.58 dB
DCT2Net (before learning) 28.02 dB
DCT2Net (after learning) 29.62 dB
Fusion denoiser 31.70 dB

上の表から,従来手法である DCT denoiser が最もデノイジングに成功していることがわかります.Fusion denoiser も次点ではありますが,Fusion により DCT denoiser の影響を多く受けているからであり,根本的には DCT2Net の結果の向上が必要です.
それではなぜ DCT2Net の結果は良くなかったのでしょうか?まず考えられるのが,論文中においても Fusion denoiser の PSNR は DCT denoiser の PSNR と比べて 1 dB 程度しか向上していないという点です.そもそも Fusion denoiser は劇的に性能を向上させるモデルではないので,パラメータの設定や学習手法が少しでも不十分であれば,簡単に従来手法の結果を下回ってしまうと考えられます.今回は学習用の画像数がかなり少なかったので,例えば,Data augumentation を行って嵩増しすることで結果が向上する可能性があります.また,誤差関数も非常に単純なものだったので,正則化項を追加するのも性能向上に一役買うと予想されます(論文中でも言及がありました).
また,今回は飛行機の画像に絞っていました.別の画像なら Fusion denoiser の結果が DCT の結果を上回る可能性はあると思います.

6. まとめ

今回は,新たに提案されたデノイジング手法である DCT2Net および Fusion denoiser の解説・実装を行いました.結果としては,提案手法はある程度の性能を示したものの,従来手法である DCT denoiser にやや劣る結果となりました.
ここからは私の感想になりますが,本論文は単純なデノイジング手法を CNN に拡張した点,そして従来手法と組み合わせた点が非常に興味深いと感じました.一方で,提案モデルが果たして CNN と主張できるのか(あまりにも浅いのではないか),あまり性能向上につながっていないのではないかという感想も持ちました.
すべて自分の力で CNN を実装したのは今回が初めてだったのでかなり実装は大変でしたが,そこそこの結果を出すことができて満足しています.最後まで読んでいただきありがとうございました!

参考文献

[1] G. Yu and G. Sapiro, "DCT image denoising: a simple and effective image denoising algorithm," Image Process. OnLine, 1, 292-296, 2011.
[2] N. Pierazzo, J.-M. Morel, and G. Facciolo, “Multi-scale DCT denoising,” Image Process. Line, 7, 288–308, 2017.

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?