13
9

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

PRML第8章 積和アルゴリズム Python実装

Posted at

第8章ではグラフィカルモデルについて説明されています。グラフィカルモデルとは、確率変数やモデルのパラメータなどの関係を図式的に表現する方法です。PRMLに載っている線形回帰などのモデルはだいたいグラフィカルモデルを用いて表すことができます。
グラフィカルモデルで表現可能なモデルに適応できる積和アルゴリズムを実装して、画像ノイズを除去してみました。この手法はPRML第13章で紹介されている隠れマルコフモデルなどにも適用されている手法の一般的な形式のようです。

画像ノイズ除去

ここでは画像中のピクセルが-1,1のどちらかの値になっている二値画像(白:1、黒:-1)を考えます。
下のような幾つかのピクセルで値が反転している画像が与えられたときに、

次のようなノイズなしの画像を復元します。

マルコフ確率場

画像ノイズ除去を行うためのモデルとして、グラフィカルモデルの一種であるマルコフ確率場(マルコフネットワーク、無向グラフィカルモデル)を使います。

下の図が今回考えるマルコフ確率場を模式的に表したものです。

$y$のノードはノイズ付きの我々が観測した画像のピクセル値を表す確率変数、$x$のノードは復元したいノイズなしの画像のピクセル値を表す確率変数としています。ノイズなしの画像では隣り合うピクセルには強い相関があると考えられるので、このモデルでも$x$では隣り合うノード同士は線で結ばれています。またノイズの割合も比較的小さいので、ノイズなし画像でのあるピクセルの値とノイズ付き画像での対応するピクセル値も強い相関があります。
例えば、ノイズなしの画像であるピクセルの値が1であれば、そのピクセルに隣接しているピクセルは1である可能性が高く、ノイズ付き画像での対応するピクセル値も1である可能性が高い。ようは線で結ばれているノードは同じピクセル値になりやすいということです。

これを数式で表すと、

p({\bf x},{\bf y}) = {1\over Z}\exp\left\{-E({\bf x},{\bf y})\right\}

ただし、ノードを表すインデックスとして$i,j$を用いて、

E({\bf x},{\bf y}) = -\alpha\sum_ix_iy_i-\beta\sum_{隣接しているiとj}x_ix_j

で、$Z$は正規化定数です。

ノイズ付きの画像が与えられれば$y_i$は具体的な観測値となるのでそれらを代入すると$p({\bf x}|{\bf y})$が得られます。この$p({\bf x}|{\bf y})$の値が最大となる${\bf x}$を復元画像とします。${\bf x}$が取りうる全状態数は$2^{ノードの数}$なので全ての${\bf x}$のパターンについて試すのは現実的ではありません。復元画像を推定する手法として反復条件付きモードというアルゴリズムもありますが、ここではより良い解が得られるらしい積和アルゴリズムをつかって復元します。

積和アルゴリズム

このアルゴリズムはグラフィカルモデル上の確率変数ノードにおいて、そのノードが取りうる値の確率を計算する手法です。線で繋がっているノードからメッセージと呼ばれる正規化していない確率のようなものをやり取りすることで確率を計算します。

$y_i$から$x_i$へのメッセージ

m_{y_i\to x_i} = \exp(\alpha x_i y_i)

ただし、ここでの$y_i$は確率変数ではなく観測値

$x_j$から$x_i$へのメッセージ

m_{x_j\to x_i} = \sum_{x_j}\exp(\beta x_ix_j)f(x_j)

ただし、$f(x_j)$はノード$j$に隣接するノード$i$以外からのメッセージの積。
$f(x_j)$を計算するには、他のノードからノード$i$へのメッセージが必要ですが、今回扱うグラフィカルモデルはループのあるモデルなので、まずメッセージをある値で初期化してから、上記二つのメッセージを送って${\bf x}$を推定します。

実装

sum_product.py
import itertools
import numpy as np
from scipy.misc import imread, imsave


ORIGINAL_IMAGE = "qiita_binary.png"
NOISY_IMAGE = "qiita_noise.png"
DENOISED_IMAGE = "qiita_denoised.png"


class Node(object):

    def __init__(self):
        self.neighbors = []
        self.messages = {}
        self.prob = None

        self.alpha = 10.
        self.beta = 5.

    def add_neighbor(self, node):
        """
        add neighboring node

        Parameters
        ----------
        node : Node
            neighboring node
        """
        self.neighbors.append(node)

    def get_neighbors(self):
        """
        get neighbor nodes

        Returns
        -------
        neighbors : list
            list containing neighbor nodes
        """
        return self.neighbors

    def init_messeges(self):
        """
        initialize messages from neighbor nodes
        """
        for neighbor in self.neighbors:
            self.messages[neighbor] = np.ones(shape=(2,)) * 0.5

    def marginalize(self):
        """
        calculate probability
        """
        prob = reduce(lambda x, y: x * y, self.messages.values())
        self.prob = prob / prob.sum()

    def send_message_to(self, node):
        """
        calculate message to be sent to the node

        Parameters
        ----------
        node : Node
            node to send computed message

        Returns
        -------
        message : np.ndarray (2,)
            message to be sent to the node
        """
        message_from_neighbors = reduce(lambda x, y: x * y, self.messages.values()) / self.messages[node]
        F = np.exp(self.beta * (2 * np.eye(2) - 1))
        message = F.dot(message_from_neighbors)
        node.messages[self] = message / message.sum()

    def likelihood(self, value):
        """
        calculate likelihood via observation, which is messege to this node

        Parameters
        ----------
        value : int
            observed value -1 or 1
        """
        assert (value == -1) or (value == 1), "{} is not 1 or -1".format(value)
        message = np.exp(self.alpha * np.array([-value, value]))
        self.messages[self] = message / message.sum()


class MarkovRandomField(object):

    def __init__(self):
        self.nodes = {}

    def add_node(self, location):
        """
        add a new node at the location

        Parameters
        ----------
        location : tuple
            key to access the node
        """
        self.nodes[location] = Node()

    def get_node(self, location):
        """
        get the node at the location

        Parameters
        ----------
        location : tuple
            key to access the corresponding node

        Returns
        -------
        node : Node
            the node at the location
        """
        return self.nodes[location]

    def add_edge(self, key1, key2):
        """
        add edge between nodes corresponding to key1 and key2

        Parameters
        ----------
        key1 : tuple
            The key to access one of the nodes
        key2 : tuple
            The key to access the other node.
        """
        self.nodes[key1].add_neighbor(self.nodes[key2])
        self.nodes[key2].add_neighbor(self.nodes[key1])

    def sum_product_algorithm(self, iter_max=10):
        """
        Perform sum product algorithm
        1. initialize messages
        2. send messages from each node to neighboring nodes
        3. calculate probabilities using the messages

        Parameters
        ----------
        iter_max : int
            number of maximum iteration
        """
        for node in self.nodes.values():
            node.init_messeges()

        for i in xrange(iter_max):
            print i
            for node in self.nodes.values():
                for neighbor in node.get_neighbors():
                    node.send_message_to(neighbor)

        for node in self.nodes.values():
            node.marginalize()


def denoise(img, n_iter=20):
    mrf = MarkovRandomField()
    len_x, len_y = img.shape
    X = range(len_x)
    Y = range(len_y)

    for location in itertools.product(X, Y):
        mrf.add_node(location)

    for x, y in itertools.product(X, Y):
        for dx, dy in itertools.permutations(range(2), 2):
            try:
                mrf.add_edge((x, y), (x + dx, y + dy))
            except Exception:
                pass

    for location in itertools.product(X, Y):
        node = mrf.get_node(location)
        node.likelihood(img[location])

    mrf.sum_product_algorithm(n_iter)

    denoised = np.zeros_like(img)
    for location in itertools.product(X, Y):
        node = mrf.get_node(location)
        denoised[location] = 2 * np.argmax(node.prob) - 1

    return denoised


def main():
    img_original = 2 * (imread(ORIGINAL_IMAGE) / 255).astype(np.int) - 1
    img_noise = 2 * (imread(NOISY_IMAGE) / 255).astype(np.int) - 1

    img_denoised = denoise(img_noise, 10)

    print "error rate before"
    print np.sum((img_original != img_noise).astype(np.float)) / img_noise.size
    print "error rate after"
    print np.sum((img_denoised != img_original).astype(np.float)) / img_noise.size
    imsave(DENOISED_IMAGE, (img_denoised + 1) / 2 * 255)

if __name__ == '__main__':
    main()

結果

ノイズ除去して復元した画像

ターミナルの出力結果

error rate before
0.109477124183
error rate after
0.0316993464052

元の画像と比較したときの誤差率が減っています。

終わりに

ノイズ除去を目的とするのであればグラフカットという手法が一番精度が良いらしいです。

13
9
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
13
9

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?