0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

くまンティック・セぐまンテーション の中身を見る

Last updated at Posted at 2020-03-02

クマさんが映ってる画像から、クマさんに関する情報を得る「くまンティック・セぐまンテーション」。前回に続いて、今回は、くまンティック・セぐまンテーションで定義したネットワークが何を見ているかを確認する作業を行ってみました。

クマさんのシルエットを認識するネットワーク

前回定義したやつと同じです。以下、「クマネットワーク」と呼ぶことにします。

import torch
from torch import nn, optim
from torch.nn import functional as F
class Kuma(nn.Module):
    def __init__(self):
        super(Kuma, self).__init__()
        # エンコーダー部分
        self.encode1 = nn.Sequential(
            *[
              nn.Conv2d(
                  in_channels = 1, out_channels = 6, kernel_size = 3, padding = 1),
              nn.BatchNorm2d(6)
              ])
        self.encode2 = nn.Sequential(
            *[
              nn.Conv2d(
                  in_channels = 6, out_channels = 16, kernel_size = 3, padding = 1),
              nn.BatchNorm2d(16)
              ])
        self.encode3 = nn.Sequential(
            *[
              nn.Conv2d(
                  in_channels = 16, out_channels = 32, kernel_size = 3, padding = 1),
              nn.BatchNorm2d(32)
              ])

        self.encode4 = nn.Sequential(
            *[
              nn.Conv2d(
                  in_channels = 32, out_channels = 64, kernel_size = 3, padding = 1),
              nn.BatchNorm2d(64)
              ])

        # デコーダー部分
        self.decode4 = nn.Sequential(
            *[
              nn.ConvTranspose2d(
                  in_channels = 64, out_channels = 32, kernel_size = 3, padding = 1),
              nn.BatchNorm2d(32)
              ])
        self.decode3 = nn.Sequential(
            *[
              nn.ConvTranspose2d(
                  in_channels = 32, out_channels = 16, kernel_size = 3, padding = 1),
              nn.BatchNorm2d(16)
              ])
        self.decode2 = nn.Sequential(
            *[
              nn.ConvTranspose2d(
                  in_channels = 16, out_channels = 6, kernel_size = 3, padding = 1),
              nn.BatchNorm2d(6)
              ])
        self.decode1 = nn.Sequential(
            *[
              nn.ConvTranspose2d(
                  in_channels = 6, out_channels = 1, kernel_size = 3, padding = 1),
              ])

    def forward(self, x):
        # エンコーダー部分
        dim_0 = x.size()      
        x = F.relu(self.encode1(x))
        x, idx_1 = F.max_pool2d(x, kernel_size = 2, stride = 2, return_indices = True)

        dim_1 = x.size() 
        x = F.relu(self.encode2(x))
        x, idx_2 = F.max_pool2d(x, kernel_size = 2, stride = 2, return_indices = True)

        dim_2 = x.size()
        x = F.relu(self.encode3(x)) 
        x, idx_3 = F.max_pool2d(x, kernel_size = 2, stride = 2, return_indices = True)

        dim_3 = x.size()
        x = F.relu(self.encode4(x)) 
        x, idx_4 = F.max_pool2d(x, kernel_size = 2, stride = 2, return_indices = True)

        # デコーダー部分
        x = F.max_unpool2d(x, idx_4, kernel_size = 2, stride = 2, output_size = dim_3)
        x = F.relu(self.decode4(x))

        x = F.max_unpool2d(x, idx_3, kernel_size = 2, stride = 2, output_size = dim_2)
        x = F.relu(self.decode3(x))

        x = F.max_unpool2d(x, idx_2, kernel_size = 2, stride = 2, output_size = dim_1)           
        x = F.relu(self.decode2(x))

        x = F.max_unpool2d(x, idx_1, kernel_size = 2, stride = 2, output_size = dim_0)           
        x = F.relu(self.decode1(x))

        x = torch.sigmoid(x)                                     

        return x

学習済みネットワークのダウンロード

前回作った学習済みクマネットワークをダウンロードします。

url = "https://github.com/maskot1977/PythonCourse2019/blob/master/kuma_050_20200226.pytorch?raw=true"
import urllib.request
urllib.request.urlretrieve(url, 'kuma_050_20200226.pytorch') # データのダウンロード
('kuma_050_20200226.pytorch', <http.client.HTTPMessage at 0x7f73177ebef0>)

ロードする

定義したクマネットワーク上に、学習済みクマネットワークをロードします。

kuma = Kuma()
kuma.load_state_dict(torch.load("kuma_050_20200226.pytorch"))
<All keys matched successfully>

中身を確認

kuma
Kuma(
  (encode1): Sequential(
    (0): Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (encode2): Sequential(
    (0): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (encode3): Sequential(
    (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (encode4): Sequential(
    (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (decode4): Sequential(
    (0): ConvTranspose2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (decode3): Sequential(
    (0): ConvTranspose2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (decode2): Sequential(
    (0): ConvTranspose2d(16, 6, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (decode1): Sequential(
    (0): ConvTranspose2d(6, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
)

パラメーターを1個1個見てみると、

for name, param in kuma.named_parameters():
    print(name, param.shape)
encode1.0.weight torch.Size([6, 1, 3, 3])
encode1.0.bias torch.Size([6])
encode1.1.weight torch.Size([6])
encode1.1.bias torch.Size([6])
encode2.0.weight torch.Size([16, 6, 3, 3])
encode2.0.bias torch.Size([16])
encode2.1.weight torch.Size([16])
encode2.1.bias torch.Size([16])
encode3.0.weight torch.Size([32, 16, 3, 3])
encode3.0.bias torch.Size([32])
encode3.1.weight torch.Size([32])
encode3.1.bias torch.Size([32])
encode4.0.weight torch.Size([64, 32, 3, 3])
encode4.0.bias torch.Size([64])
encode4.1.weight torch.Size([64])
encode4.1.bias torch.Size([64])
decode4.0.weight torch.Size([64, 32, 3, 3])
decode4.0.bias torch.Size([32])
decode4.1.weight torch.Size([32])
decode4.1.bias torch.Size([32])
decode3.0.weight torch.Size([32, 16, 3, 3])
decode3.0.bias torch.Size([16])
decode3.1.weight torch.Size([16])
decode3.1.bias torch.Size([16])
decode2.0.weight torch.Size([16, 6, 3, 3])
decode2.0.bias torch.Size([6])
decode2.1.weight torch.Size([6])
decode2.1.bias torch.Size([6])
decode1.0.weight torch.Size([6, 1, 3, 3])
decode1.0.bias torch.Size([1])

あー そーゆーことね 完全に理解した ←わかってない

カーネル(またはフィルター)

どうやら、上に示した行列パラメーターは、深層学習の世界ではカーネル(またはフィルター)と呼ばれるようです。その行列の形は上に示した通りですが、そこにどんな数値が入っているか可視化してみましょう。

import matplotlib.pyplot as plt
for name, param in kuma.named_parameters():
    print(name)
    print(param.shape)
    if len(param.shape) == 4:
        x, y, z, w = param.shape
        idx = 0
        fig = plt.figure(figsize=(x, y))
        for para in param:
            for par in para:
                idx += 1
                ax = fig.add_subplot(y, x, idx)
                im = ax.imshow(par.detach().numpy(), cmap="gray")
                ax.axis('off')
                #fig.colorbar(im)
        plt.show()
    #break
encode1.0.weight
torch.Size([6, 1, 3, 3])

output_7_1.png
※ 画像をクリックすると拡大できると思います。

encode1.0.bias
torch.Size([6])
encode1.1.weight
torch.Size([6])
encode1.1.bias
torch.Size([6])
encode2.0.weight
torch.Size([16, 6, 3, 3])

output_7_3.png
※ 画像をクリックすると拡大できると思います。

encode2.0.bias
torch.Size([16])
encode2.1.weight
torch.Size([16])
encode2.1.bias
torch.Size([16])
encode3.0.weight
torch.Size([32, 16, 3, 3])

output_7_5.png
※ 画像をクリックすると拡大できると思います。

encode3.0.bias
torch.Size([32])
encode3.1.weight
torch.Size([32])
encode3.1.bias
torch.Size([32])
encode4.0.weight
torch.Size([64, 32, 3, 3])

output_7_7.png
※ 画像をクリックすると拡大できると思います。

encode4.0.bias
torch.Size([64])
encode4.1.weight
torch.Size([64])
encode4.1.bias
torch.Size([64])
decode4.0.weight
torch.Size([64, 32, 3, 3])

output_7_9.png
※ 画像をクリックすると拡大できると思います。

decode4.0.bias
torch.Size([32])
decode4.1.weight
torch.Size([32])
decode4.1.bias
torch.Size([32])
decode3.0.weight
torch.Size([32, 16, 3, 3])

output_7_11.png
※ 画像をクリックすると拡大できると思います。

decode3.0.bias
torch.Size([16])
decode3.1.weight
torch.Size([16])
decode3.1.bias
torch.Size([16])
decode2.0.weight
torch.Size([16, 6, 3, 3])

output_7_13.png
※ 画像をクリックすると拡大できると思います。

decode2.0.bias
torch.Size([6])
decode2.1.weight
torch.Size([6])
decode2.1.bias
torch.Size([6])
decode1.0.weight
torch.Size([6, 1, 3, 3])

output_7_15.png
※ 画像をクリックすると拡大できると思います。

decode1.0.bias
torch.Size([1])

ぼくの足りない知識でいえば、今回のクマネットワークではこの「3 x 3」のカーネル(またはフィルター)で画像をスキャンして、「線」などの画像中の特徴を抜き出しているわけです。

それでも、これでクマさんをどのように認識しているのか、いまいちピンとこない。

クマさん画像をクマネットワークに通す

カーネル(またはフィルター)を見てもピンとこないので、クマさん画像をクマネットワークに通してみて、その各層でクマさん画像がどのように見えているか確認してみました。

クマさん画像の生成

前回と同じやつです。

import numpy as np
import random
from PIL import Image, ImageDraw, ImageFilter
from itertools import product

def draw_bear(n_bear=1): # ランダムにクマさんの画像を生成する
    r = g = b = 250
    im = Image.new('RGB', (400, 400), (r, g, b))
    draw = ImageDraw.Draw(im)

    for _ in range(random.randint(-1, 0)):
        r = random.randint(10, 200)
        g = random.randint(10, 200)
        b = random.randint(10, 200)
        x1 = random.randint(0, 400)
        y1 = random.randint(0, 400)
        dx = random.randint(10, 50)
        dy = random.randint(10, 50)
        draw.ellipse((x1, y1, x1+dx, y1+dy), fill=(r, g, b))

    for _ in range(n_bear):
        r = g = b = 1
        center_x = 200
        center_y = 200
        wx = 60
        wy = 50
        dx1 = 90
        dx2 = 20
        dy1 = 90
        dy2 = 20
        dx3 = 15
        dy3 = 100
        dy4 = 60
        shape1 = (center_x - wx, center_y - wy, center_x + wx, center_y + wy)
        shape2 = (center_x - dx1, center_y - dy1, center_x - dx2, center_y - dy2)
        shape3 = (center_x + dx2, center_y - dy1, center_x + dx1, center_y - dy2)
        shape4 = (center_x - dx3, center_y - dy3, center_x + dx3, center_y - dy4)

        zoom = 0.2 + random.random() * 0.4
        center_x = random.randint(-30, 250)
        center_y = random.randint(-30, 250)

        shape1 = modify(shape1, zoom=zoom, center_x=center_x, center_y=center_y)
        shape2= modify(shape2, zoom=zoom, center_x=center_x, center_y=center_y)
        shape3 = modify(shape3, zoom=zoom, center_x=center_x, center_y=center_y)
        shape4 = modify(shape4, zoom=zoom, center_x=center_x, center_y=center_y)

        draw.ellipse(shape1, fill=(r, g, b))
        draw.ellipse(shape2, fill=(r, g, b))
        draw.ellipse(shape3, fill=(r, g, b))
        #draw.ellipse(shape4, fill=(r, g, b))

    return im

def modify(shape, zoom=1, center_x=0, center_y=0):
    x1, y1, x2, y2 = np.array(shape) * zoom
    return (x1 + center_x, y1 + center_y, x2 + center_x, y2 + center_y)

class Noise: # クマさんの画像にノイズを乗せる
    def __init__(self, input_image):
        self.input_image = input_image
        self.input_pix = self.input_image.load()
        self.w, self.h = self.input_image.size

    def saltpepper(self, salt=0.05, pepper=0.05):
        output_image = Image.new("RGB", self.input_image.size)
        output_pix = output_image.load()

        for x, y in product(*map(range, (self.w, self.h))):
            r = random.random()
            if r < salt:
                output_pix[x, y] = (255, 255, 255)
            elif r > 1 - pepper:
                output_pix[x, y] = (  0,   0,   0)
            else:
                output_pix[x, y] = self.input_pix[x, y]
        return output_image

## クマさんの画像をセマンティック・セグメンテーション用の教師データに加工する
def getdata_for_semantic_segmentation(im): 
    x_im = im.filter(ImageFilter.CONTOUR)
    im2 = Noise(input_image=x_im)
    x_im = im2.saltpepper()
    a_im = np.asarray(im)
    y_im = Image.fromarray(np.where(a_im == 1, 255, 0).astype(dtype='uint8'))
    return x_im, y_im

クマネットワークを少し改変

途中結果を可視化するように改変してみました。

import torch
from torch import nn, optim
from torch.nn import functional as F
class Kuma(nn.Module):
    def __init__(self):
        super(Kuma, self).__init__()
        # エンコーダー部分
        self.encode1 = nn.Sequential(
            *[
              nn.Conv2d(
                  in_channels = 1, out_channels = 6, kernel_size = 3, padding = 1),
              nn.BatchNorm2d(6)
              ])
        self.encode2 = nn.Sequential(
            *[
              nn.Conv2d(
                  in_channels = 6, out_channels = 16, kernel_size = 3, padding = 1),
              nn.BatchNorm2d(16)
              ])
        self.encode3 = nn.Sequential(
            *[
              nn.Conv2d(
                  in_channels = 16, out_channels = 32, kernel_size = 3, padding = 1),
              nn.BatchNorm2d(32)
              ])

        self.encode4 = nn.Sequential(
            *[
              nn.Conv2d(
                  in_channels = 32, out_channels = 64, kernel_size = 3, padding = 1),
              nn.BatchNorm2d(64)
              ])

        # デコーダー部分
        self.decode4 = nn.Sequential(
            *[
              nn.ConvTranspose2d(
                  in_channels = 64, out_channels = 32, kernel_size = 3, padding = 1),
              nn.BatchNorm2d(32)
              ])
        self.decode3 = nn.Sequential(
            *[
              nn.ConvTranspose2d(
                  in_channels = 32, out_channels = 16, kernel_size = 3, padding = 1),
              nn.BatchNorm2d(16)
              ])
        self.decode2 = nn.Sequential(
            *[
              nn.ConvTranspose2d(
                  in_channels = 16, out_channels = 6, kernel_size = 3, padding = 1),
              nn.BatchNorm2d(6)
              ])
        self.decode1 = nn.Sequential(
            *[
              nn.ConvTranspose2d(
                  in_channels = 6, out_channels = 1, kernel_size = 3, padding = 1),
              ])

    def forward(self, x):
        print("forward input:", x.shape)
        draw_layer(x)
        # エンコーダー部分
        dim_0 = x.size()      
        x = F.relu(self.encode1(x))
        x, idx_1 = F.max_pool2d(x, kernel_size = 2, stride = 2, return_indices = True)
        print("after encode1:", x.shape)
        draw_layer(x)

        dim_1 = x.size() 
        x = F.relu(self.encode2(x))
        x, idx_2 = F.max_pool2d(x, kernel_size = 2, stride = 2, return_indices = True)
        print("after encode2:", x.shape)
        draw_layer(x)

        dim_2 = x.size()
        x = F.relu(self.encode3(x)) 
        x, idx_3 = F.max_pool2d(x, kernel_size = 2, stride = 2, return_indices = True)
        print("after encode3:", x.shape)
        draw_layer(x)

        dim_3 = x.size()
        x = F.relu(self.encode4(x)) 
        x, idx_4 = F.max_pool2d(x, kernel_size = 2, stride = 2, return_indices = True)
        print("after encode4:", x.shape)
        draw_layer(x)

        # デコーダー部分
        x = F.max_unpool2d(x, idx_4, kernel_size = 2, stride = 2, output_size = dim_3)
        x = F.relu(self.decode4(x))
        print("after decode4:", x.shape)
        draw_layer(x)

        x = F.max_unpool2d(x, idx_3, kernel_size = 2, stride = 2, output_size = dim_2)
        x = F.relu(self.decode3(x))
        print("after decode3:", x.shape)
        draw_layer(x)

        x = F.max_unpool2d(x, idx_2, kernel_size = 2, stride = 2, output_size = dim_1)           
        x = F.relu(self.decode2(x))
        print("after decode2:", x.shape)
        draw_layer(x)

        x = F.max_unpool2d(x, idx_1, kernel_size = 2, stride = 2, output_size = dim_0)           
        x = F.relu(self.decode1(x))
        x = torch.sigmoid(x)  
        print("after decode1:", x.shape) 
        draw_layer(x)                                  

        return x

def draw_layer(param):
    if len(param.shape) == 4:
        x, y, z, w = param.shape
        idx = 0
        fig = plt.figure(figsize=(y*2, x*2))
        for para in param:
            for par in para:
                idx += 1
                ax = fig.add_subplot(x, y, idx)
                im = ax.imshow(par.detach().numpy(), cmap="gray")
                ax.axis('off')
                #fig.colorbar(im)
        plt.show()
kuma = Kuma()
kuma.load_state_dict(torch.load("kuma_050_20200226.pytorch"))
<All keys matched successfully>

1枚の画像あたりクマさんが3匹いる画像を、4個作り、途中の層でどのように見えているか確認してみました。

X_test = [] # テスト用の画像データを格納
Y_test = [] # テスト用の正解データを格納
Z_test = [] # テスト用の予測結果を格納

for i in range(4): # 学習に用いなかった新規データを4個生成
    x_im, y_im = getdata_for_semantic_segmentation(draw_bear(3))
    X_test.append(x_im)
    Y_test.append(y_im)

# テスト用の画像データをPyTorch用に整形
X_test_a = np.array([[np.asarray(x).transpose((2, 0, 1))[0]] for x in X_test])
X_test_t = torch.tensor(X_test_a, dtype = torch.float32)

# 学習済みのモデルを使って予測値を計算
Y_pred = kuma(X_test_t)

# 予測値を ndarray として格納
for pred in Y_pred:
    Z_test.append(pred.detach().numpy())
forward input: torch.Size([4, 1, 400, 400])

output_12_1.png
※ 画像をクリックすると拡大できると思います。

after encode1: torch.Size([4, 6, 200, 200])

output_12_3.png
※ 画像をクリックすると拡大できると思います。

after encode2: torch.Size([4, 16, 100, 100])

output_12_5.png
※ 画像をクリックすると拡大できると思います。

after encode3: torch.Size([4, 32, 50, 50])

output_12_7.png
※ 画像をクリックすると拡大できると思います。

after encode4: torch.Size([4, 64, 25, 25])

output_12_9.png
※ 画像をクリックすると拡大できると思います。

after decode4: torch.Size([4, 32, 50, 50])

output_12_11.png
※ 画像をクリックすると拡大できると思います。

after decode3: torch.Size([4, 16, 100, 100])

output_12_13.png
※ 画像をクリックすると拡大できると思います。

after decode2: torch.Size([4, 6, 200, 200])

output_12_15.png
※ 画像をクリックすると拡大できると思います。

after decode1: torch.Size([4, 1, 400, 400])

output_12_17.png
※ 画像をクリックすると拡大できると思います。

エンコーダー第1層で大まかな輪郭をゲットし、第2層〜3層でノイズを除去し、第3〜4層で「輪郭に囲まれた領域」を得て、デコーダー第1〜4層で、それを元の画像上にマッピングできるよう復元していってる、といった感じですかね。

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?