13
13

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

肺のX線画像のセグメンテーション(U-Net)をやってみた

Last updated at Posted at 2022-06-30

データセット

日本放射線技術学会のMINIJSRT_DATABASEのSegmentation01(胸部X線画像の肺野領域抽出肺野領域:255/肺野外:0)を使用しました。

(画像例)
image.png

環境はGoogle Colaboratoryで、PytorchのラッパーであるPytorch-Lightningを使用しました。

Githubにソースコードを公開しておきます。

semantic_segmentation_chest-Xray_Pytorch-Lightning

事前準備

まずはデータのダウンロード。

!wget -q http://imgcom.jsrt.or.jp/imgcom/wp-content/uploads/2018/11/Segmentation01.zip
!unzip -q /content/Segmentation01.zip

続いてライブラリのインポート

あらかじめ、以下のライブラリをインストールしておきます。

# ライブラリのインストール
!pip install -q pytorch_lightning
!pip install -q monai
!pip install -q torchmetrics==0.6.0

今回はMONAIに組み込まれているU-Netを実装に使っていきます。

# ライブラリのインポート
import numpy as np
import matplotlib.pyplot as plt

from glob import glob
import argparse

import PIL
from PIL import Image

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
import torchmetrics
from torchmetrics.functional import accuracy, iou

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
from torchvision import transforms

import monai
from monai.networks.blocks import Convolution
from monai.networks.nets import UNet

ハイパーパラメータを設定します。

# ハイパーパラメータの設定
parser = argparse.ArgumentParser()
parser.add_argument('--image_size', type=int, default=256)
parser.add_argument('--batch_size', type=int, default=4)
parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--lr', type=float, default=1e-4) # 学習率
parser.add_argument('--patience', type=int, default=10) # earlystoppingの監視対象回数
param = parser.parse_args(args=[])
print(param)

DatasetとDataLoader

データファイルのパスのリストを作成しておきます。

入力画像のリスト(train_img_listtest_img_list)とラベル画像のリスト(train_label_listtest_label_list)を作成します。

# 画像ファイル名リスト
train_img_list = sorted(glob('/content/Segmentation01/train/org/*.png'))
test_img_list = sorted(glob('/content/Segmentation01/test/org/*.png'))

# ラベル画像リスト
train_label_list = sorted(glob('/content/Segmentation01/train/label/*.png'))
test_label_list = sorted(glob('/content/Segmentation01/test/label/*.png'))

続いて、Datasetを定義していきます。
画像はtorchvision.transformsResizeToTensorで前処理します。
今回はデータ拡張(Data augmentation)は行いません。
Datasetは入力画像(image)とラベル画像(label)を返すようにします。

class XrayDataset(data.Dataset):

    def __init__(self, img_path_list, label_path_list):
        self.image_path_list = img_path_list
        self.label_path_list = label_path_list
        self.transform = transforms.Compose( [transforms.Resize((param.image_size, param.image_size)),
                                              transforms.ToTensor(),])


    def __len__(self):
        return len(self.image_path_list)

    
    def __getitem__(self, idx):
        img = Image.open(self.image_path_list[idx])
        img = self.transform(img)

        label = Image.open(self.label_path_list[idx])
        label = self.transform(label)


        return img, label

trainとtest用のデータセットのインスタンスを作成します。

# Datasetのインスタンス作成
train_dataset = XrayDataset(train_img_list, train_label_list)
test_dataset = XrayDataset(test_img_list, test_label_list)

DataLoaderを作成します。

# Dataloader
dataloader = {
    'train': data.DataLoader(train_dataset, batch_size=param.batch_size, shuffle=True),
    'val': data.DataLoader(test_dataset, batch_size=param.batch_size, shuffle=False)
}

ネットワークの定義と学習

Pytorch-Lightningを使用します。

U-NetはMONAIに組み込まれているものを使用します。

ログはLoss、Accuracy、IoU(Intersection over Union)を記録します。

class Net(pl.LightningModule):

    def __init__(self, lr: float):
        super().__init__()

        self.lr = lr

        self.unet = UNet(
                  dimensions=2, in_channels=1, out_channels=1,
                  channels=(64, 128, 256, 512, 1024),
                  strides=(2, 2, 2, 2, 2)
        )


    def forward(self, x):
        h = self.unet(x)
        return h


    def training_step(self, batch, batch_idx):
        x, t = batch
        y = self(x)
        loss = F.binary_cross_entropy_with_logits(y, t)
        self.log('train_loss', loss, on_step=True, on_epoch=True)
        self.log('train_acc', accuracy(y.sigmoid(), t.int()), on_step=True, on_epoch=True, prog_bar=True)
        self.log('train_iou', iou(y.sigmoid(), t.int()), on_step=True, on_epoch=True, prog_bar=True)
        return loss


    def validation_step(self, batch, batch_idx):
        x, t = batch
        y = self(x)
        loss = F.binary_cross_entropy_with_logits(y, t)
        self.log('val_loss', loss, on_step=False, on_epoch=True)
        self.log('val_acc', accuracy(y.sigmoid(), t.int()), on_step=False, on_epoch=True)
        self.log('val_iou', iou(y.sigmoid(), t.int()), on_step=True, on_epoch=True, prog_bar=True)
        return loss


    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters())
        return optimizer

今回は最良のモデルを保存するためのModelCheckpointEarlyStoppingも実装していきます。

監視対象は両方とも'val_loss'

# callbacksの定義
model_checkpoint = ModelCheckpoint(
    SAVE_MODEL_PATH,
    filename="UNet-"+"{epoch:02d}-{val_loss:.2f}",
    monitor='val_loss',
    mode='min',
    save_top_k=1,
    save_last=False,
)
early_stopping = EarlyStopping(
    monitor='val_loss',
    mode='min',
    patience=param.patience,
)

学習を実行します。

# 訓練の実行
pl.seed_everything(0)
net = Net(lr=param.lr)
trainer = pl.Trainer(max_epochs=param.epochs, 
                     callbacks=[model_checkpoint, early_stopping], 
                     gpus=1)
trainer.fit(net, dataloader['train'], dataloader['val'])

TensorBoardで学習過程を可視化します。

%load_ext tensorboard
%tensorboard --logdir lightning_logs/

※結果は省略

結果の可視化

5枚ほどラベル画像(正解)と予測画像を可視化してみます。

# 表示する画像サイズ
n_max_imgs = 5

plt.figure(figsize=(9, 20))
for n in range(n_max_imgs):
    x, t = test_dataset[n]
    y = net(x.unsqueeze(0)) # 最初の位置(0)に新たな次元を挿入
    y_label = (y > 0.5).int().squeeze()

    t = np.squeeze(t) # numpyでtensorの余分な次元を除去

    plt.subplot(n_max_imgs, 2, 2*n+1)
    plt.imshow(t, cmap='gray')

    plt.subplot(n_max_imgs, 2, 2*n+2)
    plt.imshow(y_label, cmap='gray')

image.png

画像を見てみるとそこそこ予測はうまく行っているようです。
ただし、部分的に抜けていたり、ノイズとなっている領域が観察されます。

各評価指標も見てみます。

# 訓練データと検証データに対する最終的な結果を表示
trainer.callback_metrics
Out
{'train_acc': tensor(0.9913),
 'train_acc_epoch': tensor(0.9913),
 'train_acc_step': tensor(0.9890),
 'train_iou': tensor(0.9794),
 'train_iou_epoch': tensor(0.9794),
 'train_iou_step': tensor(0.9734),
 'train_loss': tensor(0.0216),
 'train_loss_epoch': tensor(0.0216),
 'train_loss_step': tensor(0.0262),
 'val_acc': tensor(0.9768),
 'val_iou': tensor(0.9483),
 'val_iou_epoch': tensor(0.9483),
 'val_loss': tensor(0.0678)}

検証データに対して確認を行ったところ IoU の値が0.948となっており、予測精度はそこそこ良い結果でした。

続いてはDeepLab V3を使って、より高い予測精度が出せるか検証してみたいと思います。

U-NetからDeepLab V3に変更したときの肺X線画像セグメンテーションの精度改善

13
13
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
13
13

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?