はじめに
前回、U-Netで胸部X線画像の肺領域のセマンティックセグメンテーションを行いました。
今回はさらなる精度改善を目指し、DeepLab V3を使ってトライしてみたところ、かなり精度が上がったので、記事にしました。
データセット
前回と同じです。
日本放射線技術学会のMINIJSRT_DATABASEのSegmentation01(胸部X線画像の肺野領域抽出肺野領域:255/肺野外:0)を使用しました。
環境はGoogle Colaboratoryで、PytorchのラッパーであるPytorch-Lightningを使用しました。
Githubにソースコードを公開しておきます。
事前準備
まずはデータのダウンロード。
!wget -q http://imgcom.jsrt.or.jp/imgcom/wp-content/uploads/2018/11/Segmentation01.zip
!unzip -q /content/Segmentation01.zip
続いてライブラリのインポート
あらかじめ、以下のライブラリをインストールしておきます。
# ライブラリのインストール
!pip install -q pytorch_lightning
!pip install -q torchmetrics==0.6.0
Deeplab V3はtorchvision.models.segmentation.deeplabv3
を使用します。
Deeplab V3のモデル構造はResNet-101を使います。
# ライブラリのインポート
import numpy as np
import matplotlib.pyplot as plt
from glob import glob
import argparse
import PIL
from PIL import Image
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
import torchmetrics
from torchmetrics.functional import accuracy, iou
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torchvision
from torchvision import transforms
from torchvision.models.segmentation import deeplabv3
ハイパーパラメータを設定します。
# ハイパーパラメータの設定
parser = argparse.ArgumentParser()
parser.add_argument('--image_size', type=int, default=256)
parser.add_argument('--batch_size', type=int, default=4)
parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--lr', type=float, default=1e-4) # 学習率
parser.add_argument('--patience', type=int, default=10) # earlystoppingの監視対象回数
param = parser.parse_args(args=[])
print(param)
DatasetとDataLoader
データファイルのパスのリストを作成しておきます。
入力画像のリスト(train_img_list
、test_img_list
)とラベル画像のリスト(train_label_list
、test_label_list
)を作成します。
# 画像ファイル名リスト
train_img_list = sorted(glob('/content/Segmentation01/train/org/*.png'))
test_img_list = sorted(glob('/content/Segmentation01/test/org/*.png'))
# ラベル画像リスト
train_label_list = sorted(glob('/content/Segmentation01/train/label/*.png'))
test_label_list = sorted(glob('/content/Segmentation01/test/label/*.png'))
続いて、Datasetを定義していきます。
画像はtorchvision.transforms
のResize
とToTensor
で前処理します。
今回はデータ拡張(Data augmentation)は行いません。
Datasetは入力画像(image
)とラベル画像(label
)を返すようにします。
class XrayDataset(data.Dataset):
def __init__(self, img_path_list, label_path_list):
self.image_path_list = img_path_list
self.label_path_list = label_path_list
self.transform = transforms.Compose( [transforms.Resize((param.image_size, param.image_size)),
transforms.ToTensor(),])
def __len__(self):
return len(self.image_path_list)
def __getitem__(self, idx):
img = Image.open(self.image_path_list[idx]).convert('RGB')
img = self.transform(img)
label = Image.open(self.label_path_list[idx])
label = self.transform(label)
return img, label
trainとtest用のデータセットのインスタンスを作成します。
# Datasetのインスタンス作成
train_dataset = XrayDataset(train_img_list, train_label_list)
test_dataset = XrayDataset(test_img_list, test_label_list)
DataLoaderを作成します。
# Dataloader
dataloader = {
'train': data.DataLoader(train_dataset, batch_size=param.batch_size, shuffle=True),
'val': data.DataLoader(test_dataset, batch_size=param.batch_size, shuffle=False)
}
ネットワークの定義と学習
Pytorch-Lightningを使用します。
Deeplab V3はtorchvision.models.segmentation.deeplabv3_resnet101
を使用します。
訓練済みネットワークを読み込み(pretrained=True
)、ファインチューニングを行います。
out = self(x)
の返り値のout
にはout
とaux
が辞書型で格納されていますが、out
のみを使用するため、lossはout['out']
として取り出します。
ログはLoss、Accuracy、IoU(Intersection over Union)を記録します。
class Net(pl.LightningModule):
def __init__(self, lr: float):
super().__init__()
self.lr = lr
self.model = torchvision.models.segmentation.deeplabv3_resnet101(pretrained=True)
self.model.classifier = deeplabv3.DeepLabHead(2048, 1)
def forward(self, x):
h = self.model(x)
return h
def training_step(self, batch, batch_idx):
x, t = batch
out = self(x)
y = torch.sigmoid(out['out'])
loss = F.binary_cross_entropy_with_logits(out['out'], t)
self.log('train_loss', loss, on_step=True, on_epoch=True)
self.log('train_acc', accuracy(y, t.int()), on_step=True, on_epoch=True, prog_bar=True)
self.log('train_iou', iou(y, t.int()), on_step=True, on_epoch=True, prog_bar=True)
return loss
def validation_step(self, batch, batch_idx):
x, t = batch
out = self(x)
y = torch.sigmoid(out['out'])
loss = F.binary_cross_entropy_with_logits(out['out'], t)
self.log('val_loss', loss, on_step=False, on_epoch=True)
self.log('val_acc', accuracy(y, t.int()), on_step=False, on_epoch=True)
self.log('val_iou', iou(y, t.int()), on_step=True, on_epoch=True, prog_bar=True)
return loss
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters())
return optimizer
今回は最良のモデルを保存するためのModelCheckpoint
とEarlyStopping
も実装していきます。
監視対象は両方とも'val_loss'
です。
# callbacksの定義
model_checkpoint = ModelCheckpoint(
SAVE_MODEL_PATH,
filename="UNet-"+"{epoch:02d}-{val_loss:.2f}",
monitor='val_loss',
mode='min',
save_top_k=1,
save_last=False,
)
early_stopping = EarlyStopping(
monitor='val_loss',
mode='min',
patience=param.patience,
)
学習を実行します。
# 訓練の実行
pl.seed_everything(0)
net = Net(lr=param.lr)
trainer = pl.Trainer(max_epochs=param.epochs,
callbacks=[model_checkpoint, early_stopping],
gpus=1)
trainer.fit(net, dataloader['train'], dataloader['val'])
TensorBoardで学習過程を可視化します。
%load_ext tensorboard
%tensorboard --logdir lightning_logs/
※結果は省略
結果の可視化
5枚ほどラベル画像(正解)と予測画像を可視化してみます。
# 表示する画像サイズ
n_max_imgs = 5
plt.figure(figsize=(9, 20))
for n in range(n_max_imgs):
x, t = test_dataset[n]
net.eval()
out = net(x.unsqueeze(0)) # 最初の位置(0)に新たな次元を挿入
y = torch.sigmoid(out['out'])
y_label = (y > 0.5).int().squeeze()
t = np.squeeze(t) # numpyでtensorの余分な次元を除去
plt.subplot(n_max_imgs, 2, 2*n+1)
plt.imshow(t, cmap='gray')
plt.subplot(n_max_imgs, 2, 2*n+2)
plt.imshow(y_label, cmap='gray')
画像を見てみると予測はかなりうまくいっているようです。
各評価指標も見てみます。
# 訓練データと検証データに対する最終的な結果を表示
trainer.callback_metrics
{'train_acc': tensor(0.9928),
'train_acc_epoch': tensor(0.9928),
'train_acc_step': tensor(0.9928),
'train_iou': tensor(0.9829),
'train_iou_epoch': tensor(0.9829),
'train_iou_step': tensor(0.9828),
'train_loss': tensor(0.0181),
'train_loss_epoch': tensor(0.0181),
'train_loss_step': tensor(0.0180),
'val_acc': tensor(0.9828),
'val_iou': tensor(0.9614),
'val_iou_epoch': tensor(0.9614),
'val_loss': tensor(0.0514)}
検証データに対して確認を行ったところ IoU の値が0.9614となっており、予測精度は良い結果でした。
前回、U-Netを使ったときは0.948であり、精度が改善されていることがわかります。