LoginSignup
5
2

More than 1 year has passed since last update.

U-NetからDeepLab V3に変更したときの肺X線画像セグメンテーションの精度改善

Last updated at Posted at 2022-07-06

はじめに

前回、U-Netで胸部X線画像の肺領域のセマンティックセグメンテーションを行いました。

肺のX線画像のセグメンテーション(U-Net)をやってみた

今回はさらなる精度改善を目指し、DeepLab V3を使ってトライしてみたところ、かなり精度が上がったので、記事にしました。

データセット

前回と同じです。

日本放射線技術学会のMINIJSRT_DATABASEのSegmentation01(胸部X線画像の肺野領域抽出肺野領域:255/肺野外:0)を使用しました。

(画像例)
image.png

環境はGoogle Colaboratoryで、PytorchのラッパーであるPytorch-Lightningを使用しました。

Githubにソースコードを公開しておきます。

事前準備

まずはデータのダウンロード。

!wget -q http://imgcom.jsrt.or.jp/imgcom/wp-content/uploads/2018/11/Segmentation01.zip
!unzip -q /content/Segmentation01.zip

続いてライブラリのインポート

あらかじめ、以下のライブラリをインストールしておきます。

# ライブラリのインストール
!pip install -q pytorch_lightning
!pip install -q torchmetrics==0.6.0

Deeplab V3はtorchvision.models.segmentation.deeplabv3を使用します。
Deeplab V3のモデル構造はResNet-101を使います。

# ライブラリのインポート
import numpy as np
import matplotlib.pyplot as plt

from glob import glob
import argparse

import PIL
from PIL import Image

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
import torchmetrics
from torchmetrics.functional import accuracy, iou

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torchvision
from torchvision import transforms
from torchvision.models.segmentation import deeplabv3

ハイパーパラメータを設定します。

# ハイパーパラメータの設定
parser = argparse.ArgumentParser()
parser.add_argument('--image_size', type=int, default=256)
parser.add_argument('--batch_size', type=int, default=4)
parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--lr', type=float, default=1e-4) # 学習率
parser.add_argument('--patience', type=int, default=10) # earlystoppingの監視対象回数
param = parser.parse_args(args=[])
print(param)

DatasetとDataLoader

データファイルのパスのリストを作成しておきます。

入力画像のリスト(train_img_listtest_img_list)とラベル画像のリスト(train_label_listtest_label_list)を作成します。

# 画像ファイル名リスト
train_img_list = sorted(glob('/content/Segmentation01/train/org/*.png'))
test_img_list = sorted(glob('/content/Segmentation01/test/org/*.png'))

# ラベル画像リスト
train_label_list = sorted(glob('/content/Segmentation01/train/label/*.png'))
test_label_list = sorted(glob('/content/Segmentation01/test/label/*.png'))

続いて、Datasetを定義していきます。
画像はtorchvision.transformsResizeToTensorで前処理します。
今回はデータ拡張(Data augmentation)は行いません。
Datasetは入力画像(image)とラベル画像(label)を返すようにします。

class XrayDataset(data.Dataset):

    def __init__(self, img_path_list, label_path_list):
        self.image_path_list = img_path_list
        self.label_path_list = label_path_list
        self.transform = transforms.Compose( [transforms.Resize((param.image_size, param.image_size)),
                                              transforms.ToTensor(),])


    def __len__(self):
        return len(self.image_path_list)

    
    def __getitem__(self, idx):
        img = Image.open(self.image_path_list[idx]).convert('RGB')
        img = self.transform(img)

        label = Image.open(self.label_path_list[idx])
        label = self.transform(label)


        return img, label

trainとtest用のデータセットのインスタンスを作成します。

# Datasetのインスタンス作成
train_dataset = XrayDataset(train_img_list, train_label_list)
test_dataset = XrayDataset(test_img_list, test_label_list)

DataLoaderを作成します。

# Dataloader
dataloader = {
    'train': data.DataLoader(train_dataset, batch_size=param.batch_size, shuffle=True),
    'val': data.DataLoader(test_dataset, batch_size=param.batch_size, shuffle=False)
}

ネットワークの定義と学習

Pytorch-Lightningを使用します。

Deeplab V3はtorchvision.models.segmentation.deeplabv3_resnet101を使用します。

訓練済みネットワークを読み込み(pretrained=True)、ファインチューニングを行います。

out = self(x)の返り値のoutにはoutauxが辞書型で格納されていますが、outのみを使用するため、lossはout['out']として取り出します。

ログはLoss、Accuracy、IoU(Intersection over Union)を記録します。

class Net(pl.LightningModule):

    def __init__(self, lr: float):
        super().__init__()

        self.lr = lr

        self.model = torchvision.models.segmentation.deeplabv3_resnet101(pretrained=True)
        self.model.classifier = deeplabv3.DeepLabHead(2048, 1)


    def forward(self, x):
        h = self.model(x)
        return h


    def training_step(self, batch, batch_idx):
        x, t = batch
        out = self(x)
        y = torch.sigmoid(out['out'])
        loss = F.binary_cross_entropy_with_logits(out['out'], t)
        self.log('train_loss', loss, on_step=True, on_epoch=True)
        self.log('train_acc', accuracy(y, t.int()), on_step=True, on_epoch=True, prog_bar=True)
        self.log('train_iou', iou(y, t.int()), on_step=True, on_epoch=True, prog_bar=True)
        return loss


    def validation_step(self, batch, batch_idx):
        x, t = batch
        out = self(x)
        y = torch.sigmoid(out['out'])
        loss = F.binary_cross_entropy_with_logits(out['out'], t)
        self.log('val_loss', loss, on_step=False, on_epoch=True)
        self.log('val_acc', accuracy(y, t.int()), on_step=False, on_epoch=True)
        self.log('val_iou', iou(y, t.int()), on_step=True, on_epoch=True, prog_bar=True)
        return loss


    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters())
        return optimizer

今回は最良のモデルを保存するためのModelCheckpointEarlyStoppingも実装していきます。

監視対象は両方とも'val_loss'です。

# callbacksの定義
model_checkpoint = ModelCheckpoint(
    SAVE_MODEL_PATH,
    filename="UNet-"+"{epoch:02d}-{val_loss:.2f}",
    monitor='val_loss',
    mode='min',
    save_top_k=1,
    save_last=False,
)
early_stopping = EarlyStopping(
    monitor='val_loss',
    mode='min',
    patience=param.patience,
)

学習を実行します。

# 訓練の実行
pl.seed_everything(0)
net = Net(lr=param.lr)
trainer = pl.Trainer(max_epochs=param.epochs, 
                     callbacks=[model_checkpoint, early_stopping], 
                     gpus=1)
trainer.fit(net, dataloader['train'], dataloader['val'])

TensorBoardで学習過程を可視化します。

%load_ext tensorboard
%tensorboard --logdir lightning_logs/

※結果は省略

結果の可視化

5枚ほどラベル画像(正解)と予測画像を可視化してみます。

# 表示する画像サイズ
n_max_imgs = 5

plt.figure(figsize=(9, 20))
for n in range(n_max_imgs):
    x, t = test_dataset[n]
    net.eval()
    out = net(x.unsqueeze(0)) # 最初の位置(0)に新たな次元を挿入
    y = torch.sigmoid(out['out'])
    y_label = (y > 0.5).int().squeeze()

    t = np.squeeze(t) # numpyでtensorの余分な次元を除去

    plt.subplot(n_max_imgs, 2, 2*n+1)
    plt.imshow(t, cmap='gray')

    plt.subplot(n_max_imgs, 2, 2*n+2)
    plt.imshow(y_label, cmap='gray')

image.png

画像を見てみると予測はかなりうまくいっているようです。

各評価指標も見てみます。

# 訓練データと検証データに対する最終的な結果を表示
trainer.callback_metrics
OUT
{'train_acc': tensor(0.9928),
 'train_acc_epoch': tensor(0.9928),
 'train_acc_step': tensor(0.9928),
 'train_iou': tensor(0.9829),
 'train_iou_epoch': tensor(0.9829),
 'train_iou_step': tensor(0.9828),
 'train_loss': tensor(0.0181),
 'train_loss_epoch': tensor(0.0181),
 'train_loss_step': tensor(0.0180),
 'val_acc': tensor(0.9828),
 'val_iou': tensor(0.9614),
 'val_iou_epoch': tensor(0.9614),
 'val_loss': tensor(0.0514)}

検証データに対して確認を行ったところ IoU の値が0.9614となっており、予測精度は良い結果でした。

前回、U-Netを使ったときは0.948であり、精度が改善されていることがわかります。

肺のX線画像のセグメンテーション(U-Net)をやってみた

5
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
2