データセット
日本放射線技術学会のMINIJSRT_DATABASEのSegmentation01(胸部X線画像の肺野領域抽出肺野領域:255/肺野外:0)を使用しました。
環境はGoogle Colaboratoryで、PytorchのラッパーであるPytorch-Lightningを使用しました。
Githubにソースコードを公開しておきます。
semantic_segmentation_chest-Xray_Pytorch-Lightning
事前準備
まずはデータのダウンロード。
!wget -q http://imgcom.jsrt.or.jp/imgcom/wp-content/uploads/2018/11/Segmentation01.zip
!unzip -q /content/Segmentation01.zip
続いてライブラリのインポート
あらかじめ、以下のライブラリをインストールしておきます。
# ライブラリのインストール
!pip install -q pytorch_lightning
!pip install -q monai
!pip install -q torchmetrics==0.6.0
今回はMONAIに組み込まれているU-Netを実装に使っていきます。
# ライブラリのインポート
import numpy as np
import matplotlib.pyplot as plt
from glob import glob
import argparse
import PIL
from PIL import Image
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
import torchmetrics
from torchmetrics.functional import accuracy, iou
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
from torchvision import transforms
import monai
from monai.networks.blocks import Convolution
from monai.networks.nets import UNet
ハイパーパラメータを設定します。
# ハイパーパラメータの設定
parser = argparse.ArgumentParser()
parser.add_argument('--image_size', type=int, default=256)
parser.add_argument('--batch_size', type=int, default=4)
parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--lr', type=float, default=1e-4) # 学習率
parser.add_argument('--patience', type=int, default=10) # earlystoppingの監視対象回数
param = parser.parse_args(args=[])
print(param)
DatasetとDataLoader
データファイルのパスのリストを作成しておきます。
入力画像のリスト(train_img_list
、test_img_list
)とラベル画像のリスト(train_label_list
、test_label_list
)を作成します。
# 画像ファイル名リスト
train_img_list = sorted(glob('/content/Segmentation01/train/org/*.png'))
test_img_list = sorted(glob('/content/Segmentation01/test/org/*.png'))
# ラベル画像リスト
train_label_list = sorted(glob('/content/Segmentation01/train/label/*.png'))
test_label_list = sorted(glob('/content/Segmentation01/test/label/*.png'))
続いて、Datasetを定義していきます。
画像はtorchvision.transforms
のResize
とToTensor
で前処理します。
今回はデータ拡張(Data augmentation)は行いません。
Datasetは入力画像(image
)とラベル画像(label
)を返すようにします。
class XrayDataset(data.Dataset):
def __init__(self, img_path_list, label_path_list):
self.image_path_list = img_path_list
self.label_path_list = label_path_list
self.transform = transforms.Compose( [transforms.Resize((param.image_size, param.image_size)),
transforms.ToTensor(),])
def __len__(self):
return len(self.image_path_list)
def __getitem__(self, idx):
img = Image.open(self.image_path_list[idx])
img = self.transform(img)
label = Image.open(self.label_path_list[idx])
label = self.transform(label)
return img, label
trainとtest用のデータセットのインスタンスを作成します。
# Datasetのインスタンス作成
train_dataset = XrayDataset(train_img_list, train_label_list)
test_dataset = XrayDataset(test_img_list, test_label_list)
DataLoaderを作成します。
# Dataloader
dataloader = {
'train': data.DataLoader(train_dataset, batch_size=param.batch_size, shuffle=True),
'val': data.DataLoader(test_dataset, batch_size=param.batch_size, shuffle=False)
}
ネットワークの定義と学習
Pytorch-Lightningを使用します。
U-NetはMONAIに組み込まれているものを使用します。
ログはLoss、Accuracy、IoU(Intersection over Union)を記録します。
class Net(pl.LightningModule):
def __init__(self, lr: float):
super().__init__()
self.lr = lr
self.unet = UNet(
dimensions=2, in_channels=1, out_channels=1,
channels=(64, 128, 256, 512, 1024),
strides=(2, 2, 2, 2, 2)
)
def forward(self, x):
h = self.unet(x)
return h
def training_step(self, batch, batch_idx):
x, t = batch
y = self(x)
loss = F.binary_cross_entropy_with_logits(y, t)
self.log('train_loss', loss, on_step=True, on_epoch=True)
self.log('train_acc', accuracy(y.sigmoid(), t.int()), on_step=True, on_epoch=True, prog_bar=True)
self.log('train_iou', iou(y.sigmoid(), t.int()), on_step=True, on_epoch=True, prog_bar=True)
return loss
def validation_step(self, batch, batch_idx):
x, t = batch
y = self(x)
loss = F.binary_cross_entropy_with_logits(y, t)
self.log('val_loss', loss, on_step=False, on_epoch=True)
self.log('val_acc', accuracy(y.sigmoid(), t.int()), on_step=False, on_epoch=True)
self.log('val_iou', iou(y.sigmoid(), t.int()), on_step=True, on_epoch=True, prog_bar=True)
return loss
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters())
return optimizer
今回は最良のモデルを保存するためのModelCheckpoint
とEarlyStopping
も実装していきます。
監視対象は両方とも'val_loss'
# callbacksの定義
model_checkpoint = ModelCheckpoint(
SAVE_MODEL_PATH,
filename="UNet-"+"{epoch:02d}-{val_loss:.2f}",
monitor='val_loss',
mode='min',
save_top_k=1,
save_last=False,
)
early_stopping = EarlyStopping(
monitor='val_loss',
mode='min',
patience=param.patience,
)
学習を実行します。
# 訓練の実行
pl.seed_everything(0)
net = Net(lr=param.lr)
trainer = pl.Trainer(max_epochs=param.epochs,
callbacks=[model_checkpoint, early_stopping],
gpus=1)
trainer.fit(net, dataloader['train'], dataloader['val'])
TensorBoardで学習過程を可視化します。
%load_ext tensorboard
%tensorboard --logdir lightning_logs/
※結果は省略
結果の可視化
5枚ほどラベル画像(正解)と予測画像を可視化してみます。
# 表示する画像サイズ
n_max_imgs = 5
plt.figure(figsize=(9, 20))
for n in range(n_max_imgs):
x, t = test_dataset[n]
y = net(x.unsqueeze(0)) # 最初の位置(0)に新たな次元を挿入
y_label = (y > 0.5).int().squeeze()
t = np.squeeze(t) # numpyでtensorの余分な次元を除去
plt.subplot(n_max_imgs, 2, 2*n+1)
plt.imshow(t, cmap='gray')
plt.subplot(n_max_imgs, 2, 2*n+2)
plt.imshow(y_label, cmap='gray')
画像を見てみるとそこそこ予測はうまく行っているようです。
ただし、部分的に抜けていたり、ノイズとなっている領域が観察されます。
各評価指標も見てみます。
# 訓練データと検証データに対する最終的な結果を表示
trainer.callback_metrics
{'train_acc': tensor(0.9913),
'train_acc_epoch': tensor(0.9913),
'train_acc_step': tensor(0.9890),
'train_iou': tensor(0.9794),
'train_iou_epoch': tensor(0.9794),
'train_iou_step': tensor(0.9734),
'train_loss': tensor(0.0216),
'train_loss_epoch': tensor(0.0216),
'train_loss_step': tensor(0.0262),
'val_acc': tensor(0.9768),
'val_iou': tensor(0.9483),
'val_iou_epoch': tensor(0.9483),
'val_loss': tensor(0.0678)}
検証データに対して確認を行ったところ IoU の値が0.948となっており、予測精度はそこそこ良い結果でした。
続いてはDeepLab V3を使って、より高い予測精度が出せるか検証してみたいと思います。