LoginSignup
16
12

More than 3 years have passed since last update.

セマンティックセグメンテーションを試してみる(Pytorch)

Last updated at Posted at 2019-12-31

初めに

セマンティックセグメンテーションは画像認識技術の一種で、画素毎に認識することができます。
seg.png

詳しい理論等は別に譲りますが、Pytorchを用いてセマンティックセグメンテーションを試してみたいと思います。今回はSeg-NetとかU-netとかPSP-netのような層が深くて複雑な構造のネットワークではなく、もっと浅くてシンプルで、ノートPCでも十分学習可能なネットワークを扱います。

環境は
CPU: intel(R) core(TM)i5 7200U
メモリ: 8 GB
OS: Windows10
python ver3.6.9
pytorch ver1.3.1
numpy ver1.17.4

データセットの作成

今回は自分で合成した画像を使います。上の線の画像が入力データ,下の塗りつぶした画像が教師データです。すなわち自動でペイントソフトみたいに塗りつぶしをするネットワークを作成します。
input_auto.png
correct_auto.png

学習に必要なデータを作ります。
imgsは入力データ1000枚
imgs_anoが出力データ(教師データ)1000枚
四角と四角は必ず被らないようになっており、辺の長さや四角の数もランダムで決定するようになっています。

import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import TensorDataset, DataLoader

def rectangle(img, img_ano, centers, max_side):
    """
    img     …四角形の線のみの2次元画像
    img_ano …あのーテーション画像
    centers …center座標のlist
    max_side…辺の最大長さの1/2 
    """
    if max_side < 3: #max_sideが小さすぎるとき
        max_side = 4
    #辺の長さの1/2を定義
    side_x = np.random.randint(3, int(max_side))
    side_y = np.random.randint(3, int(max_side))    

    #中心の座標,(x, y)を定義
    x = np.random.randint(max_side + 1, img.shape[0] - (max_side + 1))
    y = np.random.randint(max_side + 1, img.shape[1] - (max_side + 1))

    #過去の中心位置と近い位置が含まれた場合,inputデータをそのまま返す
    for center in centers:
        if np.abs(center[0] - x) < (2 *max_side + 1):
            if np.abs(center[1] - y) < (2 * max_side + 1):
                return img, img_ano, centers

    img[x - side_x : x + side_x, y - side_y] = 1.0      #上辺
    img[x - side_x : x + side_x, y + side_y] = 1.0      #下辺
    img[x - side_x, y - side_y : y + side_y] = 1.0      #左辺
    img[x + side_x, y - side_y : y + side_y + 1] = 1.0  #右辺
    img_ano[x - side_x : x + side_x + 1, y - side_y : y + side_y + 1] = 1.0
    centers.append([x, y])
    return img, img_ano, centers


num_images = 1000                                   #生成する画像数
length = 64                                          #画像のサイズ
imgs = np.zeros([num_images, 1, length, length])     #ゼロ行列を生成,入力画像
imgs_ano = np.zeros([num_images, 1, length, length]) #出力画像

for i in range(num_images):
    centers = []
    img = np.zeros([length, length])
    img_ano = np.zeros([64, 64])
    for j in range(6):                       #四角形を最大6つ生成
        img, img_ano, centers = rectangle(img, img_ano, centers, 12) 
    imgs[i, 0, :, :] = img
    imgs_ano[i, 0, :, :] = img_ano

imgs = torch.tensor(imgs, dtype = torch.float32)                 #ndarray - torch.tensor
imgs_ano = torch.tensor(imgs_ano, dtype = torch.float32)           #ndarray - torch.tensor
data_set = TensorDataset(imgs, imgs_ano)
data_loader = DataLoader(data_set, batch_size = 100, shuffle = True)

ネットワーク_1

次にPytorchでネットワークのクラスを作成します。
まずは前回のオートエンコーダで定義したネットワークをそのまま使いました。オートエンコーダもセグメンテーションも入力画像と同じサイズの画像を生成するので、(Pytorchの場合)使えました。Tensorflowとかどうなっているのだろう?

class ConvAutoencoder(nn.Module):
    def __init__(self):
        super(ConvAutoencoder, self).__init__()
        #Encoder Layers
        self.conv1 = nn.Conv2d(in_channels = 1,
                               out_channels = 16,
                               kernel_size = 3,
                               padding = 1)
        self.conv2 = nn.Conv2d(in_channels = 16,
                               out_channels = 4,
                               kernel_size = 3,
                               padding = 1)
        #Decoder Layers
        self.t_conv1 = nn.ConvTranspose2d(in_channels = 4, out_channels = 16,
                                          kernel_size = 2, stride = 2)
        self.t_conv2 = nn.ConvTranspose2d(in_channels = 16, out_channels = 1,
                                          kernel_size = 2, stride = 2)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        #encode#                           
        x = self.relu(self.conv1(x))        
        x = self.pool(x)                  
        x = self.relu(self.conv2(x))      
        x = self.pool(x)                  
        #decode#
        x = self.relu(self.t_conv1(x))    
        x = self.sigmoid(self.t_conv2(x))
        return x

このネットワークで学習させます。

#******ネットワークを選択******
net = ConvAutoencoder()                               
loss_fn = nn.MSELoss()                                #損失関数の定義
optimizer = optim.Adam(net.parameters(), lr = 0.01)

losses = []                                     #epoch毎のlossを記録
epoch_time = 30
for epoch in range(epoch_time):
    running_loss = 0.0                          #epoch毎のlossの計算
    net.train()
    for i, (XX, yy) in enumerate(data_loader):
        optimizer.zero_grad()       
        y_pred = net(XX)
        loss = loss_fn(y_pred, yy)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print("epoch:",epoch, " loss:", running_loss/(i + 1))
    losses.append(running_loss/(i + 1))

#lossの可視化
plt.plot(losses)
plt.ylabel("loss")
plt.xlabel("epoch time")
plt.savefig("loss_auto")
plt.show()

epoch毎の損失(loss)を可視化したものです。
epochが30回でそれなりに収束した状態でしょうか?
loss_auto.png

学習に使っていない画像を用いて、試してみます。
大まかな位置は判別できていますが、境界の付近はうまく取れていない印象があります。
output_auto.png

net.eval()            #評価モード
#今まで学習していない画像を1つ生成
num_images = 1
img_test = np.zeros([num_images, 1, length, length])
imgs_test_ano = np.zeros([num_images, 1, length, length])
for i in range(num_images):
    centers = []
    img = np.zeros([length, length])
    img_ano = np.zeros([length, length])
    for j in range(6):
        img, img_ano, centers = rectangle(img, img_ano, centers, 7)
    img_test[i, 0, :, :] = img

img_test = img_test.reshape([1, 1, 64, 64])
img_test = torch.tensor(img_test, dtype = torch.float32)
img_test = net(img_test)             #生成した画像を学習済のネットワークへ
img_test = img_test.detach().numpy() #torch.tensor - ndarray
img_test = img_test[0, 0, :, :]

plt.imshow(img, cmap = "gray")       #inputデータの可視化
plt.savefig("input_auto")
plt.show()
plt.imshow(img_test, cmap = "gray")  #outputデータの可視化
plt.savefig("output_auto")
plt.show()
plt.imshow(img_ano, cmap = "gray")   #正解データ
plt.savefig("correct_auto")
plt.plot()

ネットワークを深くしてみる。

先ほどのモデルでは十分な性能を得ることができませんでしたので、層を深くしてみたいと思います。
ここでは単に深くするだけではなく、過学習を防ぐバッチノーマライゼーションと、デコーダでアップサンプリングを入れてみます。アップサンプリングの詳しい解説はこちらの記事が分かりやすかったです。

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        #encoder
        self.encoder_conv_1 = nn.Sequential(*[
                                            nn.Conv2d(in_channels = 1, 
                                                      out_channels = 6,
                                                      kernel_size = 3,
                                                      padding = 1),
                                            nn.BatchNorm2d(6)
                                            ])

        self.encoder_conv_2 = nn.Sequential(*[
                                            nn.Conv2d(in_channels = 6,
                                                      out_channels = 16,
                                                      kernel_size = 3,
                                                      padding = 1),
                                            nn.BatchNorm2d(16)
                                            ])
        self.encoder_conv_3 = nn.Sequential(*[
                                            nn.Conv2d(in_channels = 16,
                                                      out_channels = 32,
                                                      kernel_size = 3,
                                                      padding = 1),
                                            nn.BatchNorm2d(32)
                                            ])

        #decoder
        self.decoder_convt_3 = nn.Sequential(*[
                                            nn.ConvTranspose2d(in_channels = 32,
                                                               out_channels = 16,
                                                               kernel_size = 3,
                                                               padding = 1),
                                            nn.BatchNorm2d(16)
                                            ])

        self.decoder_convt_2 = nn.Sequential(*[
                                            nn.ConvTranspose2d(in_channels = 16,
                                                               out_channels = 6,
                                                               kernel_size = 3,
                                                               padding = 1),
                                            nn.BatchNorm2d(6)
                                            ])

        self.decoder_convt_1 = nn.Sequential(*[
                                            nn.ConvTranspose2d(in_channels = 6,
                                                               out_channels = 1,
                                                               kernel_size = 3,
                                                               padding = 1)
                                            ])

    def forward(self, x):
        #encoder
        dim_0 = x.size()                    
        x = F.relu(self.encoder_conv_1(x))                            
        x, indices_1 = F.max_pool2d(x, kernel_size = 2,
                                    stride = 2, 
                                    return_indices = True)  #indiceでmaxpoolの位置を記録          
        dim_1 = x.size()
        x = F.relu(self.encoder_conv_2(x))                            
        x, indices_2 = F.max_pool2d(x, kernel_size = 2,
                                    stride = 2, 
                                    return_indices = True)            

        dim_2 = x.size()
        x = F.relu(self.encoder_conv_3(x))
        x, indices_3 = F.max_pool2d(x, kernel_size = 2,
                                    stride = 2, 
                                    return_indices = True)

        #decoder
        x = F.max_unpool2d(x, indices_3, kernel_size = 2,
                           stride = 2, output_size = dim_2)
        x = F.relu(self.decoder_convt_3(x))

        x = F.max_unpool2d(x, indices_2, kernel_size = 2,
                           stride = 2, output_size = dim_1)           
        x = F.relu(self.decoder_convt_2(x))                           

        x = F.max_unpool2d(x, indices_1, kernel_size = 2,
                           stride = 2, output_size = dim_0)           
        x = F.relu(self.decoder_convt_1(x))                           
        x = torch.sigmoid(x)                                       

        return x

このネットワークに切り替えるのは簡単で

#******ネットワークを選択******
net = ConvAutoencoder()

となっている場所を、新たに作成したクラスに変更するだけです。

#******ネットワークを選択******
net = Net()

lossの変遷をグラフ化します。

loss_auto.png

学習に使用していないデータを入力して、正解画像と比較してみます。
output.png

セグメンテーションできていることが分かります。

終わりに

簡単なセグメンテーションを今回試しました。実用には程遠いシンプルなモデルでしたが、雰囲気は大体つかめたような気がします。

16
12
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
16
12