1
4

More than 3 years have passed since last update.

クマ...じゃなくて セマンティック・セグメンテーション

Last updated at Posted at 2020-01-24

画像を扱うディープラーニング(深層学習)のうち、注目しているモノを抜き出す「セマンティック・セグメンテーション」について勉強してみました。

クマ画像の生成

セマンティック・セグメンテーションを学ぶにあたって、最初に困ったのが、画像の取得。いい感じのがなかなか見当たらなかったので、練習用の画像を自動生成するところから行いました。

コアラとクマの画像を自動生成する で作成した関数を用いて、セマンティック・セグメンテーションのための「画像データ」と「正解データ」を自作します。

from PIL import ImageFilter
import numpy as np
def getdata_for_semantic_segmentation(im):
    x_im = im.filter(ImageFilter.CONTOUR) # 輪郭をとったものを「画像データ」し入力に用いる
    a_im = np.asarray(im) # numpy に変換
    # 黒クマさんを白クマさんにし、それ以外を黒にしたものを「正解データ」とする
    y_im = Image.fromarray(np.where(a_im == 1, 255, 0).astype(dtype='uint8'))
    return x_im, y_im

次のようにして、2000個のデータセットを作成しました。

X_data = [] # 画像データ格納用
Y_data = [] # 正解データ格納用
for i in range(2000): # 画像を2000個生成する
    # クマの画像を生成
    im = koala_or_bear(bear=True, rotate=True , resize=64, others=True)
    # セマンティック・セグメンテーション用に加工
    x_im, y_im = getdata_for_semantic_segmentation(im)
    X_data.append(x_im) # 画像データ
    Y_data.append(y_im) # 正解データ

作成した画像データ、正解データの最初の8件だけ図示して確認。

%matplotlib inline
import matplotlib.pyplot as plt
fig = plt.figure(figsize=(10,10))
for i in range(16):
    ax = fig.add_subplot(4, 4, i+1)
    ax.axis('off')
    if i < 8: # 画像データのトップ8を表示
        ax.set_title('input_{}'.format(i))
        ax.imshow(X_data[i],cmap=plt.cm.gray, interpolation='none')
    else: # 正解データのトップ8を表示
        ax.set_title('answer_{}'.format(i - 8))
        ax.imshow(Y_data[i - 8],cmap=plt.cm.gray, interpolation='none')
plt.show()

output_3_0.png

セマンティック・セグメンテーションの目的は、上の「画像データ」から、クマさんに相当する部分を抜き出して「正解データ」を出力するようなモデルを構築することです。

クマンティック・セグメンテーションモデルの構築

データの整形

import torch
from torch.utils.data import TensorDataset, DataLoader

# 画像データと正解データを ndarray に変換
X_a = np.array([[np.asarray(x).transpose((2, 0, 1))[0]] for x in X_data])
Y_a = np.array([[np.asarray(y).transpose((2, 0, 1))[0]] for y in Y_data])

# ndarray の画像データと正解データを tensor に変換
X_t = torch.tensor(X_a, dtype = torch.float32)               
Y_t = torch.tensor(Y_a, dtype = torch.float32)

# PyTorch で学習するためにデータローダーに格納
data_set = TensorDataset(X_t, Y_t)
data_loader = DataLoader(data_set, batch_size = 100, shuffle = True)

モデルの定義

セマンティック・セグメンテーションを行うモデルは、基本的には、畳み込みニューラルネットワーク (CNN)を使ったオートエンコーダーです。

from torch import nn, optim
from torch.nn import functional as F
class Kuma(nn.Module):
    def __init__(self):
        super(Kuma, self).__init__()
        # エンコーダー部分
        self.encode1 = nn.Sequential(
            *[
              nn.Conv2d(
                  in_channels = 1, out_channels = 6, kernel_size = 3, padding = 1),
              nn.BatchNorm2d(6)
              ])
        self.encode2 = nn.Sequential(
            *[
              nn.Conv2d(
                  in_channels = 6, out_channels = 16, kernel_size = 3, padding = 1),
              nn.BatchNorm2d(16)
              ])
        self.encode3 = nn.Sequential(
            *[
              nn.Conv2d(
                  in_channels = 16, out_channels = 32, kernel_size = 3, padding = 1),
              nn.BatchNorm2d(32)
              ])

        # デコーダー部分
        self.decode3 = nn.Sequential(
            *[
              nn.ConvTranspose2d(
                  in_channels = 32, out_channels = 16, kernel_size = 3, padding = 1),
              nn.BatchNorm2d(16)
              ])
        self.decode2 = nn.Sequential(
            *[
              nn.ConvTranspose2d(
                  in_channels = 16, out_channels = 6, kernel_size = 3, padding = 1),
              nn.BatchNorm2d(6)
              ])
        self.decode1 = nn.Sequential(
            *[
              nn.ConvTranspose2d(
                  in_channels = 6, out_channels = 1, kernel_size = 3, padding = 1),
              ])

    def forward(self, x):
        # エンコーダー部分
        dim_0 = x.size() # デコーダー第1層でサイズを元に戻すとき用             
        x = F.relu(self.encode1(x))
        # return_indices = True にして、デコーダーで max_pool の位置idxを用いる
        x, idx_1 = F.max_pool2d(x, kernel_size = 2, stride = 2, return_indices = True)
        dim_1 = x.size() # デコーダー第2層でサイズを元に戻すとき用
        x = F.relu(self.encode2(x))
        # return_indices = True にして、デコーダーで max_pool の位置idxを用いる                       
        x, idx_2 = F.max_pool2d(x, kernel_size = 2, stride = 2, return_indices = True)            
        dim_2 = x.size()
        x = F.relu(self.encode3(x)) # デコーダー第3層でサイズを元に戻すとき用
        # return_indices = True にして、デコーダーで max_pool の位置idxを用いる
        x, idx_3 = F.max_pool2d(x, kernel_size = 2, stride = 2, return_indices = True)

        # デコーダー部分
        x = F.max_unpool2d(x, idx_3, kernel_size = 2, stride = 2, output_size = dim_2)
        x = F.relu(self.decode3(x))
        x = F.max_unpool2d(x, idx_2, kernel_size = 2, stride = 2, output_size = dim_1)           
        x = F.relu(self.decode2(x))                           
        x = F.max_unpool2d(x, idx_1, kernel_size = 2, stride = 2, output_size = dim_0)           
        x = F.relu(self.decode1(x))                           
        x = torch.sigmoid(x)                                     

        return x

学習

%%time

kuma = Kuma()
loss_fn = nn.MSELoss()                               
optimizer = optim.Adam(kuma.parameters(), lr = 0.01)

total_loss_history = []                                     
epoch_time = 50
for epoch in range(epoch_time):
    total_loss = 0.0                          
    kuma.train()
    for i, (XX, yy) in enumerate(data_loader):
        optimizer.zero_grad()       
        y_pred = kuma(XX)
        loss = loss_fn(y_pred, yy)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print("epoch:",epoch, " loss:", total_loss/(i + 1))
    total_loss_history.append(total_loss/(i + 1))

plt.plot(total_loss_history)
plt.ylabel("loss")
plt.xlabel("epoch time")
plt.savefig("total_loss_history")
plt.show()
epoch: 0  loss: 8388.166772460938
epoch: 2  loss: 8372.164868164062
epoch: 3  loss: 8372.035913085938
...
epoch: 48  loss: 8371.781372070312
epoch: 49  loss: 8371.78125

損失関数の値がとんでもない数値になってますが、大丈夫でしょうか...

output_6_1.png

収束したようです。計算時間は次のとおり。

CPU times: user 6min 7s, sys: 8.1 s, total: 6min 16s
Wall time: 6min 16s

結果発表

テスト用のデータとして、新たにデータを生成します。

X_test = [] # テスト用の画像データを格納
Y_test = [] # テスト用の正解データを格納
Z_test = [] # テスト用の予測結果を格納

for i in range(100): # 学習に用いなかった新規データを100個生成
    im = koala_or_bear(bear=True, rotate=True, resize=64, others=True)
    x_im, y_im = getdata_for_semantic_segmentation(im)
    X_test.append(x_im)
    Y_test.append(y_im)

学習済みのモデルを用いて予測します。

# テスト用の画像データをPyTorch用に整形
X_test_a = np.array([[np.asarray(x).transpose((2, 0, 1))[0]] for x in X_test])
X_test_t = torch.tensor(X_test_a, dtype = torch.float32)

# 学習済みのモデルを使って予測値を計算
Y_pred = kuma(X_test_t)

# 予測値を ndarray として格納
for pred in Y_pred:
    Z_test.append(pred.detach().numpy())

先頭10データの描画

どういう予測結果か、先頭10データだけ描画して見てみましょう。左から順に、入力画像データ、正解データ、予測データです。

# データの先頭10個に対して、画像データ、正解データ、予測値を描画
fig = plt.figure(figsize=(6,18))
for i in range(10):
    ax = fig.add_subplot(10, 3, (i * 3)+1)
    ax.axis('off')
    ax.set_title('input_{}'.format(i))
    ax.imshow(X_test[i])
    ax = fig.add_subplot(10, 3, (i * 3)+2)
    ax.axis('off')
    ax.set_title('answer_{}'.format(i))
    ax.imshow(Y_test[i])
    ax = fig.add_subplot(10, 3, (i * 3)+3)
    ax.axis('off')
    ax.set_title('predicted_{}'.format(i))
    yp2 = Y_pred[i].detach().numpy()[0] * 255
    z_im = Image.fromarray(np.array([yp2, yp2, yp2]).transpose((1, 2, 0)).astype(dtype='uint8'))
    ax.imshow(z_im)
plt.show()

output_7_0.png

クマさんの部分を切り出せています。クマさんの大きさが変化しても、回転しても大丈夫です!

ですが、ちょっと大きめに切り出す傾向にあるようです。また、間違えている部分もけっこう見受けられます。

切り出し面積

正解データと予測データで、白く切り出した面積を比較してみましょう。

A_ans = []
A_pred = []
for yt, zt in zip(Y_test, Z_test):
    # 正解の白の面積(ベクトルが3色分あるので3で割る)
    A_ans.append(np.where(np.asarray(yt) > 0.5, 1, 0).sum() / 3) 
    A_pred.append(np.where(np.asarray(zt) > 0.5, 1, 0).sum()) # 予測値の白の面積

plt.figure(figsize=(4, 4))
plt.scatter(A_ans, A_pred, alpha=0.5)
plt.grid()
plt.xlabel('Observed sizes of bears')
plt.ylabel('Predicted sizes of bears')
plt.xlim([0, 1700])
plt.ylim([0, 1700])
plt.show()

output_8_0.png

正解値と予測値がほぼ直線関係にあるのは良いですが、予測値は大きく出る傾向があるようです。

クマさんの大きさだけ知りたい場合は、この関係を元に補正を行うと良いでしょう。ですが、切り出しをより正確に行うためには、もっと複雑なモデルを構築する必要があることでしょう。

1
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
4