LoginSignup
32
18

More than 3 years have passed since last update.

Visual Feature Attribution using Wasserstein GANs - GANでクラス間のピクセルの違いを抽出 -

Last updated at Posted at 2019-08-11

今回はCVPR2018から判断根拠系の論文です。とうとうこの分野にもGANが入ってきました。論文リンクはこちら

よくあるニューラルネットワークの判断根拠は「ある分類・回帰モデルがどの入力を見て出力を得ているか?」に近いのに対し,この研究は「データの中でラベルXXX足らしめている特徴は何か?」を表す手法になります。正確にいうとモデルの判断根拠ではなく,データ内の判断に使われるであろう特徴を抽出する手法となります。

医療分野の組織なので,実験はアルツハイマーを見分ける脳画像です。

概要

  • 画像の二値分類において,そのクラスで重要な画像領域をGANを用いて抽出する手法VA-GANの提案
  • クラスAの画像にGeneraterから作成したマスクを足し合わせた画像を,DiscriminaterがクラスBの画像と比較して,Generaterのマスクがかかっているか否かを学習する
  • GeneraterはクラスBに変更するマスク,すなわちそのクラス特有の特徴を表すようなマスクを生成するようになる

どういうことができるか?

image.png

ガウシアンノイズで作成した画像と,それに四角形の領域(画像上)を足した画像を用意し,それぞれラベル0, ラベル1というラベルをつけておきます。
この画像に対して,判断根拠的な手法と提案手法を用いて,ラベル1を見分けるのに重要なピクセルを抽出する実験を行います。

image.png

実験結果がこちら。左の3つが既存手法,一番右のVA-GANが提案手法です。既存手法ではノイズが混じっていたり,輪郭しか出せなかったり,加えたマスクをぼやっとした領域でしか出せていませんが,提案手法は領域をはっきり抽出できています。

どうやってやってるの?

ほぼこの図で説明されます。
image.png

以下のフローで学習を行います。

  1. あるクラス1の画像$X$をGenerater $M$に入力する
  2. $M$の出力$M(X)$と画像$X$の和をとり,画像$Y$を出力する
  3. Discriminator $D$は$X$とは別のクラス(ここでは0)の画像$X'$と,2.で出力された画像$Y$を比較し,それが本物の画像か生成された画像か判別する

直感的に,$M$はクラス1の画像をクラス0に変更するマスクを生成するように学習することとなります。すなわち,そのマスクはクラス0とクラス1の異なる部分であることとなります。

学習方法

ロス関数は以下の式で表されます。


一番上は本物の入力に対して$D$が0と判断した確率(つまり正解した確率)からマスクされた画像を1と判断した確率(つまり間違った確率)を引いたものです。

二番目は正則化項です。これは画像そのものを変更してしまうのを防ぐために,マスクをある程度スパースにして変更箇所を限定するために入れているようです。

三番目は全体のロス関数です。

実験

アルツハイマーかそうでないかを二値分類する脳画像データセットを用いて,アルツハイマーの原因となる箇所(脳の縮小が起きるので脳のエッジ部分)を抽出できるか可視化を行いました。

image.png

ObservedがGrand Truthです。他の手法と比べて提案手法は実際に縮小した部分を表現できています。他の手法は四角を与えるプレ実験のように,ノイズが多かったりぼやけていたりしてうまく抽出できてないですね。

Grand Truthと抽出したマップのNCC(いわゆる正規化された画像の輝度の内積)を出した結果は以下です。

image.png

提案手法が一番高い精度を出せています。

まとめと所感

モデルがどう判断した,ではなくデータ自体がクラス間でどのような異なる性質を持っているかをニューラルネットで示した論文は意外と少ない気がしています。こっち系の論文も増えるとまた一段難しいタスクにも適用できそうですね!

32
18
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
32
18