LoginSignup
9
5

More than 1 year has passed since last update.

【Pytorch】3DのUNetを実装する

Last updated at Posted at 2022-07-24

はじめに

【前回】UNetを実装する
本記事は前回の記事の続きとなります。前回はMRIの各断面の画像から小腸・大腸・胃の領域を予測する為に2DのUNetを実装しました。
しかし、MRI画像は本質的には幅×高さ×深さの3Dの情報を有しており、2DのUNetではこれを幅×高さの2Dの画像として学習するため、深さ方向の情報を失っていると考えられます。そこで今回は3DでUNetを実装し、2Dと同様に臓器の領域を予測することが可能か調べました。

UNet

【参考】セグメンテーションのモデル
【原著論文】U-Net: Convolutional Networks for Biomedical Image Segmentation

UNetはセグメンテーションと呼ばれるタスクを処理するために考案されたモデルです。セグメンテーションとは画像の1ピクセルごとにどのクラスに属するか予測するタスクであり、代表的な応用事例として、自動運転・医療画像解析が挙げられます。

セグメンテーションを行うモデルはいくつかの構造に大別されますが、UNetはエンコーダ・デコーダ構造の代表的なモデルとなります。
このモデルは序盤(エンコーダ)に畳み込み層を用いた特徴量抽出を行い、終盤(デコーダ)で確立マップを出力します。

UNetの特徴は、skip connectionでエンコーダとデコーダの出力をチャンネル方向に結合することです(下図の灰色矢印)。これにより物体の位置情報を保持して高精度の予測を可能としています。

(原著論文より)

データセット

データセットとしてKaggleの「UW-Madison GI Tract Image Segmentation」のデータセットを使用しました。本コンペティションはMRIの画像から小腸・大腸・胃の領域を予測し、その精度を競うものです。
comp.png

環境

  • Google Colaboratory Pro

コード

モジュールのimport

import pandas as pd
import numpy as np
import cv2
from matplotlib import pyplot as plt
from PIL import Image
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as BaseDataset
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms
from torchvision.transforms import functional
import albumentations as A
!pip install -U segmentation-models-pytorch
import segmentation_models_pytorch as smp

データのダウンロード・前処理は省略します。データセットのtrain.csvにラベル情報が含まれており、これをpngファイルとして保存しています。

学習用にデータフレームを作成しました。学習用のデータフレームは250行×294列から成り、1行が1つの3Dデータに該当します。列は幅、高さ、深さなどのパラメータの他、各画像のパス、各ラベルのパスから構成されています。3Dデータの深さは80か144であるであるため、画像のパスとラベルのパスはそれぞれ80あるいは144個記載されています。(後のデータローダーで深さ方向に結合していきます。)

train_df = pd.read_csv("./train_3D144_df.csv")
val_df = pd.read_csv("./val_3D144_df.csv")
test_df = pd.read_csv("./test_3D144_df.csv")

print(train_df.shape)
print(train_df.columns)
実行結果
(250, 294)
Index(['id', 'caseday', 'imgpath_0', 'imgpath_1', 'imgpath_2',
       'imgpath_3', 'imgpath_4', 'imgpath_5', 'imgpath_6',
       ...
       'labelpath_138', 'labelpath_139', 'labelpath_140', 'labelpath_141',
       'labelpath_142', 'labelpath_143', 'depth', 'height', 'width', 'label'],
      dtype='object', length=294)

データローダーを作成します。
Albumentationsで画像の水増しを行いますが、複数枚の画像に同一の処理を行う方法が分からず、かなり苦戦しました。
こちらはまた別の記事でまとめます。
以下は80枚or144枚の画像にまとめて同一の処理を行うように記述しています。

additional_image_targets = {f"image{i+1}": "image" for i in range(80 - 1)}
additional_label_targets = {f"mask{i+1}": "mask" for i in range(80 - 1)}
additional_targets80 = dict(additional_image_targets, **additional_label_targets)

data_transforms80 = A.Compose([
      A.HorizontalFlip(p=0.5),
      A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.05, rotate_limit=10, p=0.5),
      A.OneOf([
          A.GridDistortion(num_steps=5, distort_limit=0.05, p=1.0),
          A.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, p=1.0)
      ], p=0.25),],
      additional_targets = additional_targets80, p=1.0)

additional_image_targets = {f"image{i+1}": "image" for i in range(144 - 1)}
additional_label_targets = {f"mask{i+1}": "mask" for i in range(144 - 1)}
additional_targets144 = dict(additional_image_targets, **additional_label_targets)

data_transforms144 = A.Compose([
      A.HorizontalFlip(p=0.5),
      A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.05, rotate_limit=10, p=0.5),
      A.OneOf([
          A.GridDistortion(num_steps=5, distort_limit=0.05, p=1.0),
          A.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, p=1.0)
      ], p=0.25),],
      additional_targets = additional_targets144, p=1.0)

以下、コードが冗長ですが次の処理を行っています。

1.各深さの画像とラベルを読込み、(256,256)にリサイズした後深さ方向に結合
2.Albumentationsの引数とするため辞書型のデータを作成
3.Albumentationsによるデータ拡張
4.変換後のデータを再度深さ方向に結合
5.画像は各RGBの値を0~1に変換(255で割る)
6.ラベルは0(background)、1(large_bowel)、2(small_bowel)、3(stomach)で表されているため、これをOne-Hot エンコーディング
7.画像・ラベルデータを(バッチサイズ、チャンネル数、深さ、高さ、幅)に変換

class Dataset(BaseDataset):
  def __init__(
      self,
      df,
      ):
    self.dimension_list = list(df.depth)
    self.imgpath_list = [[df.iloc[i,:][f"imgpath_{h}"] for h in range(self.dimension_list[i])] for i in range(len(df))]
    self.labelpath_list = [[df.iloc[i,:][f"labelpath_{h}"] for h in range(self.dimension_list[i])] for i in range(len(df))]
    self.transform80 = data_transforms80
    self.transform144 = data_transforms144

  def __getitem__(self, i):
    imgpathes = self.imgpath_list[i]
    labelpathes = self.labelpath_list[i]
    dim = self.dimension_list[i]
    for j in range(dim):  #画像とラベルデータの読み込み
      img = cv2.imread(imgpathes[j])
      img = cv2.resize(img, dsize = (256, 256))
      label = Image.open(labelpathes[j])
      label = np.asarray(label)
      label = cv2.resize(label, dsize = (256, 256))
      if j == 0:  #深さ方向に結合
        img_3D = [img]
        label_3D = [label]
      else:
        img_3D = np.vstack([img_3D, [img]])
        label_3D = np.vstack([label_3D, [label]])
      
    d1 = {"image": img_3D[0,:,:,:]}  #Albumentationsに代入する為の辞書型データを作成
    d2 = {f"image{i+1}": img_3D[i+1,:,:,:] for i in range(dim - 1)}
    d3 = {"mask": label_3D[0,:,:]}
    d4 = {f"mask{i+1}": label_3D[i+1,:,:] for i in range(dim - 1)}
    dic = dict(d1, **d2, **d3, **d4)

    if dim == 80:  #深さが80の場合
      transformed = self.transform80(**dic)
    else:  #深さが144の場合
      transformed = self.transform144(**dic)

    for j in range(dim):
      if j == 0:  #データ拡張後のデータを再度深さ方向に結合
        img_3D = [transformed["image"]]
        label_3D = [transformed["mask"]]
      else:
        img_3D = np.vstack([img_3D, [transformed[f"image{j}"]]])
        label_3D = np.vstack([label_3D, [transformed[f"mask{j}"]]])

    img_3D = img_3D/255  #RGBの値を0~1に
    img_3D = torch.from_numpy(img_3D.astype(np.float32)).clone()
    img_3D = img_3D.permute(3, 0, 1, 2)  #(チャンネル数、深さ、高さ、幅)に変換
    label_3D = torch.from_numpy(label_3D.astype(np.float32)).clone()
    label_3D = torch.nn.functional.one_hot(label_3D.long(), num_classes=4) #One-Hot エンコーディング
    label_3D = label_3D.to(torch.float32)
    label_3D = label_3D.permute(3, 0, 1, 2)  #(チャンネル数、深さ、高さ、幅)に変換
    data = {"img": img_3D, "label": label_3D}
    return data
  
  def __len__(self):
    return len(self.imgpath_list)


class valtest_Dataset(BaseDataset):  #Albumentationsによるデータ拡張を行わない
  def __init__(
      self,
      df,
      ):
    self.dimension_list = list(df.depth)
    self.imgpath_list = [[df.iloc[i,:][f"imgpath_{h}"] for h in range(self.dimension_list[i])] for i in range(len(df))]
    self.labelpath_list = [[df.iloc[i,:][f"labelpath_{h}"] for h in range(self.dimension_list[i])] for i in range(len(df))]

  def __getitem__(self, i):
    imgpathes = self.imgpath_list[i]
    labelpathes = self.labelpath_list[i]
    dim = self.dimension_list[i]
    for j in range(dim):
      img = cv2.imread(imgpathes[j])
      img = cv2.resize(img, dsize = (256, 256))
      label = Image.open(labelpathes[j])
      label = np.asarray(label)
      label = cv2.resize(label, dsize = (256, 256))
      if j == 0:
        img_3D = [img]
        label_3D = [label]
      else:
        img_3D = np.vstack([img_3D, [img]])
        label_3D = np.vstack([label_3D, [label]])

    img_3D = img_3D/255
    img_3D = torch.from_numpy(img_3D.astype(np.float32)).clone()
    img_3D = img_3D.permute(3, 0, 1, 2)
    label_3D = torch.from_numpy(label_3D.astype(np.float32)).clone()
    label_3D = torch.nn.functional.one_hot(label_3D.long(), num_classes=4)
    label_3D = label_3D.to(torch.float32)
    label_3D = label_3D.permute(3, 0, 1, 2)
    data = {"img": img_3D, "label": label_3D}
    return data
  
  def __len__(self):
    return len(self.imgpath_list)
BATCH_SIZE = 1  #2以上ではout of memoryのエラーが出る

train_dataset = Dataset(train_df)
train_loader = DataLoader(train_dataset,
                          batch_size=BATCH_SIZE,
                          num_workers=4,
                          shuffle=True)

val_dataset = valtest_Dataset(val_df)
val_loader = DataLoader(val_dataset,
                          batch_size=BATCH_SIZE,
                          num_workers=4,
                          shuffle=True)

test_dataset = valtest_Dataset(test_df)
test_loader = DataLoader(test_dataset,
                          batch_size=1,
                          num_workers=4)

Unetを構築します。基本的な構造は前回の2DのUNetと同じで、各層を以下のように2Dから3Dに変更しています。
nn.Conv2d → nn.Conv3d、
nn.BatchNorm2d → nn.BatchNorm3d
nn.MaxPool2d → nn.MaxPool3d

また、out of memoryへの対策として、畳み込み第1層の出力チャンネル数を64から8に減らしています。

class TwoConvBlock_3D(nn.Module):
    def __init__(self, in_channels, middle_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv3d(in_channels, middle_channels, kernel_size = 3, padding="same")
        self.bn1 = nn.BatchNorm3d(middle_channels)
        self.rl = nn.ReLU()
        self.conv2 = nn.Conv3d(middle_channels, out_channels, kernel_size = 3, padding="same")
        self.bn2 = nn.BatchNorm3d(out_channels)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.rl(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.rl(x)
        return x

class UpConv_3D(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="trilinear", align_corners=True)
        self.bn1 = nn.BatchNorm3d(in_channels)
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size = 2, padding="same")
        self.bn2 = nn.BatchNorm3d(out_channels)

    def forward(self, x):
        x = self.up(x)
        x = self.bn1(x)
        x = self.conv(x)
        x = self.bn2(x)
        return x

class UNet_3D(nn.Module):
    def __init__(self):
        super().__init__()
        self.TCB1 = TwoConvBlock_3D(3, 8, 8)  #2D_UNetでは(3, 64, 64)
        self.TCB2 = TwoConvBlock_3D(8, 32, 32)
        self.TCB3 = TwoConvBlock_3D(32, 64, 64)
        self.TCB4 = TwoConvBlock_3D(64, 512, 512)
        self.TCB5 = TwoConvBlock_3D(512, 1024, 1024)

        self.TCB6 = TwoConvBlock_3D(1024, 512, 512)
        self.TCB7 = TwoConvBlock_3D(128, 64, 64)
        self.TCB8 = TwoConvBlock_3D(64, 32, 32)
        self.TCB9 = TwoConvBlock_3D(16, 8, 8)

        self.maxpool = nn.MaxPool3d(2, stride = 2)
        
        self.UC1 = UpConv_3D(1024, 512) 
        self.UC2 = UpConv_3D(512, 64) 
        self.UC3 = UpConv_3D(64, 32) 
        self.UC4 = UpConv_3D(32, 8)

        self.conv1 = nn.Conv3d(8, 4, kernel_size = 1)

    def forward(self, x):
        x = self.TCB1(x)
        x1 = x
        x = self.maxpool(x)

        x = self.TCB2(x)
        x2 = x
        x = self.maxpool(x)

        x = self.TCB3(x)
        x3 = x
        x = self.maxpool(x)

        x = self.TCB4(x)
        x4 = x
        x = self.maxpool(x)

        x = self.TCB5(x)

        x = self.UC1(x)
        x = torch.cat([x4, x], dim = 1)
        x = self.TCB6(x)

        x = self.UC2(x)
        x = torch.cat([x3, x], dim = 1)
        x = self.TCB7(x)

        x = self.UC3(x)
        x = torch.cat([x2, x], dim = 1)
        x = self.TCB8(x)

        x = self.UC4(x)
        x = torch.cat([x1, x], dim = 1)
        x = self.TCB9(x)

        x = self.conv1(x)

        return x

GPU、最適化アルゴリズムの設定を行います。

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

unet = UNet_3D().to(device)
optimizer = optim.Adam(unet.parameters(), lr=0.001)

損失関数を設定します。損失はTversky LossとBCEWithLogits Lossの平均としました。これらの関数は損失関数内でソフトマックス関数を処理する為、UNetの最後にソフトマックス関数を適用していません。

TverskyLoss = smp.losses.TverskyLoss(mode='multilabel', log_loss=False)
BCELoss     = smp.losses.SoftBCEWithLogitsLoss()
def criterion(pred,target):
    return 0.5*BCELoss(pred, target) + 0.5*TverskyLoss(pred, target)

学習を行います。

history = {"train_loss": []}
n = 0
m = 0

for epoch in range(15):
  train_loss = 0
  val_loss = 0

  unet.train()
  for i, data in enumerate(train_loader):
    inputs, labels = data["img"].to(device), data["label"].to(device)
    optimizer.zero_grad()
    outputs = unet(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    train_loss += loss.item()
    history["train_loss"].append(loss.item())
    n += 1
    if i % ((len(train_df)//BATCH_SIZE)//10) == (len(train_df)//BATCH_SIZE)//10 - 1:
      print(f"epoch:{epoch+1}  index:{i+1}  train_loss:{train_loss/n:.5f}")
      n = 0
      train_loss = 0

  unet.eval()
  with torch.no_grad():
    for i, data in enumerate(val_loader):
      inputs, labels = data["img"].to(device), data["label"].to(device)
      outputs = unet(inputs)
      loss = criterion(outputs, labels)
      val_loss += loss.item()
      m += 1
      if i % (len(val_df)//BATCH_SIZE) == len(val_df)//BATCH_SIZE - 1:
        print(f"epoch:{epoch+1}  index:{i+1}  val_loss:{val_loss/m:.5f}")
        m = 0
        val_loss = 0

  scheduler.step()
  torch.save(unet.state_dict(), f"./train_{epoch+1}.pth")
print("finish training")

損失の推移をプロットします。
縦軸がlossで横軸が各バッチを示します。

plt.figure()
plt.plot(history["train_loss"])
plt.xlabel('batch')
plt.ylabel('train_loss')

plt.figure()
plt.plot(history["val_loss"])
plt.xlabel('batch')
plt.ylabel('validation_loss')

trainloss.png
valloss.png

testデータに対して予測を行います。
validation用データセットの損失が最も低かったエポックの重みを使用します。

model = UNet_3D()
model.load_state_dict(torch.load("./train_13.pth"))
model.eval()

data = next(iter(test_loader))
inputs, labels = data["img"], data["label"]
outputs = model(inputs)

sigmoid = nn.Sigmoid()
outputs = sigmoid(outputs)
pred = torch.argmax(outputs, axis=1)
pred = torch.nn.functional.one_hot(pred.long(), num_classes=4).to(torch.float32)

元の画像と予測結果を表示します。

plt.figure()
plt.imshow(data["img"][0,:,88,:,:].permute(1, 2, 0))
plt.title("original_image")
plt.axis("off")

plt.figure()
classes = ["background","large_bowel","small_bowel","stomach"]
fig, ax = plt.subplots(2, 4, figsize=(15,8))
for i in range(2):
  for j, cl in enumerate(classes):
    if i == 0:
      ax[i,j].imshow(pred[0,88,:,:,j])
      ax[i,j].set_title(f"pred_{cl}")
      ax[i,j].axis("off")
    else:
      ax[i,j].imshow(data["label"][0,j,88,:,:])    
      ax[i,j].set_title(f"label_{cl}")
      ax[i,j].axis("off")

image3.png
image4.png

2DUNetほど高精度に予測できていません。画像の水増しにも限界があり、データの絶対量が足らないという印象です(畳み込み第1層の出力チャンネル数を減らしたことが原因の可能性もあります)。
2DのUNetとの精度の比較は、2.5DのUNet(2D_Convolutionと3D_Convolutionの組み合わせ)を実装する際に行います。

9
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
9
5