2
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

UNet(系)で、マルチクラスセグメンテーション(2)

Last updated at Posted at 2018-09-03

はじめに

この記事は、前記事からの続きです。
詳しい話は、そちらを。

改良 (1)

ということで、惜しい感じにはなったが、
今回の条件では、2クラスしか対象にしておらず、そこに対して、RGBの3channelを出力しようとしており、どのクラスにも対応していない(=対象物がない)G channelが、うまく学習できなかったんじゃないだろうか

2クラスでも、この2クラスでRGBの全channelを使う様な設定
(例えば、1クラス目を(255, 255, 0)、2クラス目を(0, 255, 255)にするなど)
すれば、うまく行くのかもしれないが、なんかそれは、嫌だった(ほんとになんとなく、負けた気がして。)
RGBの3色出力にするとして、クラス数が3以上でも、使わない色素があるかもしれないじゃないか!
という思いもあった。(たぶん)

ということで、
なんとかならないものかと思い
まず考えたのが、そもそも全体が0だと、領域計算ができないので、lossが(ちゃんと)計算できないじゃん
ということで、無視できる値を決めて、教師データの全体が0のchannelに対して、その無視できる値で埋めてしまって学習してみたらどうだろうか。
出力時に、その無視できる値以下の点を0に抑えてしまえばいけそうな気がする。

学習 (1)

前回からの変更点は、
教師画像に対して、NNへの入力前に、全体が0のchannelに、無視できる値(10(/255))を入れて学習する様にした点。
それ以外は、同じにして学習。

結果 (1)

loss曲線は、こんな感じ
UNet:
LearningCurve_loss_2.png

DeepUNet:
(またもや紛失...すみません)

そして、出力結果はこんな感じ
(一部抜粋。各チャンネル10以下の出力は、0に変換。)
Cond_2_Train.png
Cond_2_Test.png

元画像の形が、Gチャンネルに残ることはなくなったが、
無視できる量に学習できていない。

結局、対象物がいないchannelは、全体が(前回は、0, 今回は、10(/255)で)フラット(全体に対して値の差がない)なので、
それがいけないのかもしれない。
(後から思えば、もしかしたら、教師画像のGチャンネルに無視できる値を入れたつもりで失敗して大きな値を入れてしまっていたのかもしれないが...)

考えてみれば、dice lossを使用しているが、(この計算方法がいけないかもしれないが)
領域そのものが重なっていなくても、領域範囲が一致していて、トータル値が一致してしまえば、lossが一致しそう。

実際、以下の様に確認すると

dice_coef_loss_chk.py
def dice_coef(y_true, y_pred):
    y_true = KB.flatten(y_true)
    y_pred = KB.flatten(y_pred)
    intersection = KB.sum(y_true * y_pred)
    denominator = KB.sum(y_true) + KB.sum(y_pred)
    if denominator == 0:
        return 1
    if intersection == 0:
        return 1 / (denominator + 1)
    return (2.0 * intersection) / denominator


def pt(t):
    val = KB.get_value(t)
    print(val)

gt = KB.cast(tf.constant([1,1,1,1,0,0,0,0]), 'float32')
a = KB.cast(tf.constant([1,1,1,1,0,0,0,0]), 'float32')
b = KB.cast(tf.constant([0,0,2,2,0,0,0,0]), 'float32')
c = KB.cast(tf.constant([1,2,0,1,0,0,0,0]), 'float32')


a_loss = 1 - dice_coef(gt, a)
b_loss = 1 - dice_coef(gt, b)
c_loss = 1 - dice_coef(gt, c)

pt(a_loss)
pt(b_loss)
pt(c_loss)

出力↓

a_loss :  0.0
b_loss :  0.0
c_loss :  0.0

ってな感じで一致してしまう。

(つぎ記事へつづく。。。)

2
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?