ジョブカン事業部アドベントカレンダー24日目です!
はじめに
初めまして!ジョブカン事業部インターン生の@YSK_08です!
私は大学院で顕微鏡の研究をしているのですが、論文調査でDeep Image Prior(DIP) という手法について知ったので、今回ご紹介したいと思います。
DIP は、学習済みモデルを使わず、ネットワーク構造そのものを「事前知識」として利用するという、少し変わったアプローチを取ります。ノイズ除去やインペインティング、超解像、圧縮によるアーチファクトの除去など、色々なタスクに使えますが、今回はノイズ除去に挑戦したいと思います。
この記事では
- DIPの肝となる畳み込みニューラルネットワークについて
- DIPではどうやって画像復元を行うのか
を簡単にお話しした後に、実際に動かしてみようと思います!
畳み込みニューラルネットワークの概要
DIPの仕組みについてお話しする前に、まずは畳み込みニューラルネットワーク(CNN)の学習の仕組みについて簡単に説明したいと思います。CNNとは画像処理で使われるニューラルネットワークです。

https://ex-ture.com/blog/2021/01/11/pytorch-cnn/
CNNは入力層、隠れ層、出力層から構成されます。入力層では画像を受け取り、隠れ層では画像にどんな特徴があるかを分析します。隠れ層は複数の層からなり、第1層では入力$X$に対して、重み$W$をかけることで、特徴量を抽出します。
$$ z = X * W $$
$W$はフィルターのようなものです。以下のような行列を考えた場合、左上から右下への斜め方向の特徴が強調されるといった感じです。
W =
\begin{bmatrix}
1 & 0 & 0 \\
0 & 1 & 0 \\
0 & 0 & 1
\end{bmatrix}
複数の$W$を用いることで、様々な特徴量を得ることが出来ます。例えば、色の濃淡に関するマップや、エッジ、丸みに関するマップなどです。$z$は活性化関数というものに渡され、その結果を次の層への入力とします。
$$ A = \phi(z) $$
この操作を多数の層で行うことで、高次元の特徴量を抽出していきます。最後に出力層でそれらを集計し、出力を決めます。
学習初期では$W$はランダムな値であり、出力も欲しいものとは大きく異なります。そのため、訓練用のデータとの比較を行い、誤差(損失)を求めます。すごく単純な式で示すと以下になりますが、損失関数は扱うデータやタスクによって良い感じに設定する必要があります。
$$ L = ||z^{out} - y||$$
ここで、$z^{out}$は出力、$y$は訓練データになります。
この損失が小さくなるように、今度は出力層->隠れ層->入力層と逆方向に情報を伝達していき、各層の重みを調節します。重みの更新については少し複雑なので、今回は省きたいと思います。ここまでの流れを何度も繰り返すことで、重みが最適化され、欲しい出力が得られるようになります。
Deep Imaging Priorの概要
ここまででCNNにおける学習の仕組みをお話ししました。ニューラルネットワークでは訓練データを使うため、質の高いデータを事前に集めておく必要があります。しかし、そのようなデータを集めるのは大変だと思います。
これに対し、大量のデータを用いたモデルのトレーニングを行わなくても、1枚の画像からでもニューラルネットワークが良い感じにやってくれるよというのがDIPになります。ただ、学習を全く行わないというわけではなく、教師あり学習のような事前学習が不要という点にご注意ください。
DIPでは固定したノイズ$z$を入力とし、ニューラルネットワークによって画像$f_{\theta}(z)$を出力します。そして、ノイズが含まれた画像$x_0$と比較し、逆伝搬を行います。これを何度も繰り返すと最終的には$f_{\theta}(z)\approx x_0$となり、ノイズの含まれた画像が出力されます。しかし、学習を良い感じのところで止めてあげるとノイズが消えた、クリーンな画像を得ることが出来ます。
なぜこんなことが起こるのかというのを実験したのが下図です。$x_0$として、クリーン画像にした場合と、ノイズ画像にした場合とでの、パラメータの更新回数に対する損失の変化を比較したものです。

Ulyanov, D.,et al., Deep Imaging Prior, 2017.
縦軸は損失を表しています。一番下の青線がノイズのないクリーンな画像、一番上の紫がノイズ画像です。青線は学習初期に急速に損失が下がるのに対し、紫線では損失の減少が遅く、学習までに時間がかかることが分かります。つまり、ニューラルネットワークは画像の構造を先に学び、ノイズのような不規則な構造は後からゆっくり学ぶため、ノイズを学ぶ前に学習を止めればノイズ除去が達成できるという仕組みです。
実装
TensorFlowで実装していきたいと思います。ネットワーク構造は論文に書かれているものと大体同じですが、今回は「ノイズが少し無くなっているな」というのが検証できればOKなので、チャネル数は少なくしています。
Network
ネットワークの全体構造です。

Ulyanov, D.,et al., Deep Imaging Prior, 2017.
ダウンサンプリング層はConv+BN+LeakyReLU+Conv+BN+LeakyReLUとなっています。LeakyReLUは以下のような活性化関数で、ReLU関数とは違い、負の値でもわずかに勾配を持つので、勾配消失を防ぎやすいという特徴があります。

Li, Z., et al, cardiGAN: A Generative Adversarial Network Model for Design and Discovery of Multi-Principal Element Alloys, 2022.
ダウンサンプリング自体は、畳み込み層でストライドを2にすることで行っています。ストライドとは畳み込みフィルターを「どれだけ移動させるか」です。
def down_block(self, x, channels):
x = tf.keras.layers.Conv2D(channels, 3, strides=2, padding="same")(x)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.LeakyReLU()(x)
x = tf.keras.layers.Conv2D(channels, 3, padding="same")(x)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.LeakyReLU()(x)
return x
アップサンプリング層はBN+Conv+BN+LeakyReLU+Conv+BN+LeakyReLUで、最後にBilinear補間により空間を拡大しています。転置畳み込みでも同様の処理が出来ますが、結果はやや悪くなるようです。アーチファクトが生じやすいからとかでしょうか。
def up_block(self, x, channels):
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.Conv2D(channels, 3, padding="same")(x)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.LeakyReLU()(x)
x = tf.keras.layers.Conv2D(channels, 3, padding="same")(x)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.LeakyReLU()(x)
return x
def upsample(self, x):
return tf.image.resize(
x,
size=(tf.shape(x)[1] * 2, tf.shape(x)[2] * 2),
method="bilinear"
)
スキップ接続層はConv+BN+LeakyReLUになります。単純に結合するのではなく、畳み込み処理を入れてから結合します。
def skip_block(self, x, channels=4):
x = tf.keras.layers.Conv2D(channels, 1, padding="same")(x)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.LeakyReLU()(x)
return x
モデルの構築部分です。層を積み上げていくだけなので、単純ですね。
def build_model(self):
inputs = tf.keras.layers.Input(
shape=(self.img_height, self.img_width, 3)
)
d1 = self.down_block(inputs, 8)
d2 = self.down_block(d1, 16)
d3 = self.down_block(d2, 32)
d4 = self.down_block(d3, 64)
d5 = self.down_block(d4, 128)
s1 = self.skip_block(d1, 4)
s2 = self.skip_block(d2, 4)
s3 = self.skip_block(d3, 4)
s4 = self.skip_block(d4, 4)
x = self.upsample(d5)
x = tf.keras.layers.Concatenate()([x, s4])
x = self.up_block(x, 128)
x = self.upsample(x)
x = tf.keras.layers.Concatenate()([x, s3])
x = self.up_block(x, 64)
x = self.upsample(x)
x = tf.keras.layers.Concatenate()([x, s2])
x = self.up_block(x, 32)
x = self.upsample(x)
x = tf.keras.layers.Concatenate()([x, s1])
x = self.up_block(x, 16)
x = self.upsample(x)
x = self.up_block(x, 8)
outputs = tf.keras.layers.Conv2D(
3, 1, activation="sigmoid"
)(x)
return tf.keras.models.Model(inputs, outputs)
学習
損失関数は平均二乗誤差 (MSE)とします。生成した画像とターゲットの画像の2乗誤差です。
$$ L = ||f_{\theta}(z) - x_0||^2 $$
パラメータ$\theta$の最適化にはAdamというアルゴリズムを使用します。最適化アルゴリズムも色々ありますが、迷ったらAdamでいいと思います。
class Denoiser:
def __init__(self):
pass
def build(self, width, height):
self.img_width = width
self.img_height = height
self.network = Network(img_width=width, img_height=height).get_model()
self.optimizer = tf.keras.optimizers.Adam(1e-2)
@tf.function
def train_step(self, gaussian, target_image):
with tf.GradientTape(persistent=True) as r_tape:
# forward propagation
predict_image = self.network(gaussian, training=True)
loss = tf.reduce_mean(tf.square(tf.squeeze(predict_image) - tf.squeeze(target_image)))
# back propagation
gradients = r_tape.gradient(loss, self.network.trainable_variables)
self.optimizer.apply_gradients(zip(gradients, self.network.trainable_variables))
return predict_image
def execute(
self,
target_image,
max_epochs,
):
gaussian = tf.random.uniform((1, self.img_height, self.img_width, 3))
for epoch in range(max_epochs):
predict_image = self.train_step(gaussian, target_image)
return predict_image[0].numpy()
実験結果
CNNのパラメータの更新を2000回繰り返した結果です。左から順にノイズのない画像、ホワイトノイズを加えた画像、DIPによるノイズ除去した画像です。
| Clean image | Noisy image | Denoised image |
|---|---|---|
![]() |
![]() |
![]() |
少しきれいになったかなと思います。もっといい感じにしたいところですが、今回はお試しなので、これで良しとします。
きれいになっていく様子をGifにしてみました。最初は芸術的な画像が出力されていますが、徐々に復元されていってますね。

今回はクリーン画像があるので、出力画像とクリーン画像とのPSNRを計算し、パラメータの更新に対する変化をプロットしてみました。ちなみに、PSNRとは画像の劣化度を示す指標で、高いほど劣化が少ないと評価できます。
プロットする前に、「ある段階で出力画像にノイズが乗り始めて、PSNRは下がり始めるのかな」と思いましたが、そうならなかったのが不思議です。もっとパラメータの更新回数を増やせば下がり始めるんでしょうか。
最後に
Deep Imaging Priorという画像復元の手法を紹介しました。対象の画像1枚だけでいろいろな処理ができるのは、やはり面白いポイントだと思います。(試しに動かしてみるというのもやりやすいですし)
本当は超解像やインペインティングにも挑戦してみたかったのですが、時間的な制約で断念しました…
いずれ機会があれば、別のタスクにも挑戦してみたいと思います。
最後になりますが、DONUTSでは新卒中途インターン問わず積極的に採用活動を行っています。
弊社にご興味を持っていただけた方はぜひご応募ください!



