0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

Grand-challenge向けの環境準備 Part.4 (Google Colab上で学習プログラムを動かす)

Posted at

Part1Part2Part3から引き続き、Grand challengeのチュートリアルになっている 網膜画像から血管の領域をセグメンテーションする例を動かしてみた際のメモです。

Grand Challengeのコンペに参加する際は、以下のように進めたいと考えています。

  1. Google Colab上で学習プログラムを作って機械学習の重みファイルを作成する (→この記事でやり方を確認します)
  2. Google Colab上で推論プログラムを作って推論できるかを確認する (→Part3でやり方を確認しました)
  3. ローカルPCで推論プログラムGrand Challengeに提出できるように整える(→Part2でやり方を確認しました)

6. Google Colabでのモデルの作成・学習

ここでは、Google Colab上で学習を行い、学習結果の重みファイル(前記事のbest_metric_model_segmentation2d_dict.pthに相当)の算出を行います。具体的には、drive-vessel-unet -> train.py の内容をGoogle Colabで動かします。

Grand challenge特有のモジュールである evalutils は推論のみをサポートしているようです。学習プログラムは evalutils は使わないので、Grand challenge特有の制約は特にありません。

6.1 必要なファイルのGoolge driveへの保存

学習用データ(datasets)とrequirement.txtは 5.1項でGoogle Driveに保存した通りです。
train.py は、dataloader.pyを読み込んで使っています。dataloader.pyを読み出せるように、Google Driveに置きました。

Google Driveの保存先
マイドライブ/
└ Colab Notebooks/
	└ drive-vessels-net/
		├ datasets/
			├ (省略)
		├ requirement.txt
		└ dataloader.py

6.2 Google ColabからGoogle Driveのマウント

5.2項と同じです。

Colabセル1
#Google driveをマウントする
from google.colab import drive
drive.mount('./gdrive')
drive_root_dir="/content/gdrive/MyDrive/Colab Notebooks/drive-vessels-unet/"

6.3 必要なモジュール等のpip/import

train.py で使っている必要なモジュール等のpip/importを行います。

6.3.1 pip

5.3.1項と同じです。

Colabセル2
!pip install -r "/content/gdrive/MyDrive/Colab Notebooks/drive-vessels-unet/requirements.txt"

6.3.2 import

必要なモジュールをimportします。train.py の冒頭でimportしている部分をコピーしました。ただし、dataloader.pyを読み込むためにPathを追加しています。

Colabセル2
import logging
import os
import sys

from glob import glob
import numpy as np

import torch
import torch.nn.functional as F

# dataloaderを読み込むためにPathを追加する
sys.path.append("/content/gdrive/MyDrive/Colab Notebooks/drive-vessels-unet/") 
from dataloader import DRIVEDataset, Rescale, ToTensor, Normalize, WeightMap
from torch.utils.data import DataLoader
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter

import monai
from monai.visualize import plot_2d_or_3d_image

6.4 学習部分

train.pymain()部分をコピーして学習を実行させます。Google Colabで動かすためにファイルパスを修正しました。

ただし、後述のしますがこのままでは学習がうまくいかず修正が必要でした。

Colabセル3
def main():
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # このファイルパスを修正しています。
    IMAGE_ROOT = drive_root_dir+"datasets/training/images"
    LABEL_ROOT = drive_root_dir+"datasets/training/1st_manual"

    images = glob(os.path.join(IMAGE_ROOT, "*training.tif"))        # 要修正点1
    labels = glob(os.path.join(LABEL_ROOT, "*manual1.gif"))

    data_transform = transforms.Compose([        # 要修正点2
        Rescale((512, 512)),
        Normalize(),
        WeightMap(),
        ToTensor(),
    ])

    train_ds = DRIVEDataset(images[:-10], labels[:-10], transform=data_transform)
    valid_ds = DRIVEDataset(images[-10:], labels[-10:], transform=data_transform)

    train_loader = DataLoader(
        train_ds, batch_size=4, shuffle=True, num_workers=0, pin_memory=True,
    )

    val_loader = DataLoader(
        valid_ds, batch_size=2, shuffle=True, num_workers=0, pin_memory=True,
    )

    # create UNet and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = monai.networks.nets.UNet(
        dimensions=2,
        in_channels=3,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss_function = F.binary_cross_entropy_with_logits        # 要修正点3
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-5)

    # start a typical PyTorch training
    epochs_total = 1000
    val_interval = 1
    best_loss = np.inf
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    writer = SummaryWriter()
    for epoch in range(epochs_total):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{epochs_total}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels, weights = (
                batch_data["img"].to(device),
                batch_data["seg"].to(device),
                batch_data["map"].to(device),
            )
            optimizer.zero_grad()
            outputs = model(inputs).squeeze()
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            epoch_len = len(train_ds) // train_loader.batch_size
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
            writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                epoch_val_loss = []
                val_images = None
                val_labels = None
                val_outputs = None
                for val_data in val_loader:
                    val_images, val_labels, val_weights = (
                        val_data["img"].to(device),
                        val_data["seg"].to(device),
                        val_data["map"].to(device),
                    )
                    val_outputs = model(val_images).squeeze()
                    loss = loss_function(val_outputs, val_labels)
                    epoch_val_loss.append(loss.item())
                epoch_val_loss = np.array(epoch_val_loss).mean()
                if epoch_val_loss < best_loss:
                    best_loss = epoch_val_loss
                    best_metric_epoch = epoch + 1
                    torch.save(
                        model.state_dict(), 
                        drive_root_dir+"best_metric_model_segmentation2d_dict.pth" # このファイルパスも修正しています。
                    )
                    print("saved new best metric model")
                print(
                    "current epoch: {} current val loss: {:.4f} best val loss: {:.4f} at epoch {}".format(
                        epoch + 1, epoch_val_loss, best_loss, best_metric_epoch
                    )
                )
                writer.add_scalar("val_loss", epoch_val_loss, epoch + 1)
                # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="image")
                plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="label")
                plot_2d_or_3d_image(
                    val_outputs, epoch + 1, writer, index=0, tag="output"
                )

    print(f"train completed, best_loss: {best_loss:.4f} at epoch: {best_metric_epoch}")
    writer.close()


if __name__ == "__main__":

    main()

私がGoogle Colabで実行した場合は、GPUモードで約45分で学習が終わりました。

(参考)Colab出力
Colab出力
MONAI version: 0.8.1
Numpy version: 1.21.5
Pytorch version: 1.10.0+cu111
MONAI flags: HAS_EXT = False, USE_COMPILED = False
MONAI rev id: 71ff399a3ea07aef667b23653620a290364095b1

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 3.0.2
scikit-image version: 0.18.3
Pillow version: 7.1.2
Tensorboard version: 2.8.0
gdown version: 4.2.2
TorchVision version: 0.11.1+cu111
tqdm version: 4.63.0
lmdb version: 0.99
psutil version: 5.4.8
pandas version: 1.3.5
einops version: NOT INSTALLED or UNKNOWN VERSION.
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies

/content/gdrive/MyDrive/Colab Notebooks/drive-vessels-unet/datasets/training/images
/content/gdrive/MyDrive/Colab Notebooks/drive-vessels-unet/datasets/training/1st_manual
----------
epoch 1/1000
1/2, train_loss: 0.9578
2/2, train_loss: 0.9108
3/2, train_loss: 0.8854
epoch 1 average loss: 0.9180
saved new best metric model
current epoch: 1 current val loss: 0.8503 best val loss: 0.8503 at epoch 1
(省略)
----------
epoch 1000/1000
1/2, train_loss: 0.0848
2/2, train_loss: 0.0748
3/2, train_loss: 0.0842
epoch 1000 average loss: 0.0813
current epoch: 1000 current val loss: 0.5436 best val loss: 0.2660 at epoch 743
train completed, best_loss: 0.2660 at epoch: 743

しかし、この学習した重みファイルを使ってをPart3の通り推論するとテストデータに対して真っ黒な画像が出力されました。
この学習プログラムにはいくつか問題がありました。

修正点1 : 画像ファイル名の一覧を取得する際にソートを追加

初めに、出力画像が真っ黒だったので、間違って真っ黒な画像を学習していないかと疑い、学習に使っている画像が正しく読み込めているかを確認しました。(確認方法は修正点2のコードで、「#デバック用に表示する場合」とコメントアウトしている部分をコメントアウトを解除することで表示しました。)

すると、画像は読み込めているものの、入力画像と正解画像の組み合わせが間違って読み込まれているときがあることがわかりました。
原因は、glob()で画像ファイル名の一覧を取得する際に、順番が保証されないため、入力画像と正解画像の順番が一致しない場合があるためのようです(参考)。修正のためsorted()を追加しました。

Colabセル3(修正部のみ抜粋)
    images = sorted(glob(os.path.join(IMAGE_ROOT, "*training.tif")))  # sortedの追加が必要
    labels = sorted(glob(os.path.join(LABEL_ROOT, "*manual1.gif")))

修正点2 : 学習用画像の拡張(Data augmentation)の追加

学習・テスト用のデータはDRIVE databaseのものを使っていますが、学習用のデータは20枚しかありません。しかも、この学習プログラムでは20枚中10枚をValidation(学習がうまくいっているかの検証用)に使っているため、学習に使っているのは10枚しかありません。
学習用の画像が不足していますので、Data augmentationを行うことにしました。

Data augmentation には色々な方法がありますが、今回は以下を採用することしました。

  • ランダムな水平方向の反転
  • ランダムな垂直方向の反転
  • ランダムな回転(±180度)、ランダムなXY方向のズレ(XYともに±5%)、ランダムな拡大・縮小(±10%)
  • ランダムな画像の明るさの変更(±40%)、ランダムなコントラストの変更(±50%)

Pytorchでは、通常は、transforms.Composetransforms.RandomHorizontalFlip()などの用意された関数を追記するだけでData augmentationが実現できます。(参考
しかし、今回のプログラムではデータセットをDRIVEDatasetという独自クラスを使っているため、Data augmentationを定義してやる必要があります。

まず、torchvision.transforms.functionalrandomを使いたいためimportを追加します。

Colabセル2(追記部のみ抜粋)
# Data augmentationのために追加
import torchvision.transforms.functional as TF
import random

以下の関数クラスを追記しました。

  • Data Augmentationの方法ごとに関数を分けました。各関数に共通する部分はdataloader.pyclass WeightMap(object):class Rescale(object):class Normalize(object):を参考にしました。
  • Data Augmentationの方法ごとの処理は、PyTorch公式ドキュメントのFunctional Transformsを参考にしました。
Colabセル3(追記部のみ抜粋)
# data augmentation用のクラス
class RandomHorizontalFlip(object):
    def __call__(self, sample):
        image, label = sample["img"], sample["seg"]
        if random.random() > 0.5:
            image = TF.hflip(image)
            label = TF.hflip(label)        
        return {"img": image, "seg": label}

class RandomVerticalFlip(object):
    def __call__(self, sample):
        image, label = sample["img"], sample["seg"]
        if random.random() > 0.5:
            image = TF.vflip(image)
            label = TF.vflip(label)
        return {"img": image, "seg": label}

class ColorJitter(object):
    def __call__(self, sample):
        image, label = sample["img"], sample["seg"]
        # Color jitteは元画像にだけ掛ける。正解画像には掛けない。
        image = transforms.ColorJitter(brightness=0.4, contrast=0.5)(image)
        return {"img": image, "seg": label}

class RandomAffine(object):
    def __call__(self, sample):
        image, label = sample["img"], sample["seg"]
        angle = random.randint(-180, 180)
        x_translate = random.uniform(-0.05, 0.05)
        y_translate = random.uniform(-0.05, 0.05)
        scale = random.uniform(0.9, 1.1)

        image = TF.affine(image, angle, [x_translate, y_translate], scale, 0.0)
        label = TF.affine(label, angle, [x_translate, y_translate], scale, 0.0)

        #デバック用に表示する場合
#        display(image)
#        display(label)

        return {"img": image, "seg": label}

追記した関数がtransforms.Composeのなかに追加します。学習時にのみData Aumentationを行い、Validation時には行わないように設定しました。transforms.Composeはその順番通りにデータ変換が適用されます。

Colabセル3(修正部のみ抜粋)
def main():
#    (前略)
    data_transform_for_train = transforms.Compose([
        Rescale((512, 512)),
        RandomHorizontalFlip(),   #追加
        RandomVerticalFlip(),     #追加
        RandomAffine(),           #追加
        ColorJitter(),            #追加
        Normalize(),
        WeightMap(),
        ToTensor(),
    ])

    data_transform = transforms.Compose([
        Rescale((512, 512)),
        Normalize(),
        WeightMap(),
        ToTensor(),
    ])

    train_ds = DRIVEDataset(images[:-10], labels[:-10], transform=data_transform_for_train)
    valid_ds = DRIVEDataset(images[-10:], labels[-10:], transform=data_transform)
#    (後略)

Data Augmentation を行った場合の学習結果を、Part3の通り推論すると以下のような出力となりました。それっぽい画像が出ていますが、drive-vessel-unet -> best_metric_model_segmentation2d_dict.pth の学習済み重みを使って推論した場合と比べると、細かい血管の推定がうまくいっていないようです。

image.png

修正点3 : 損失関数の変更

網膜画像では大部分が血管以外なので、画像全体を血管以外と判定すれば75%以上の面積で正解となります。こういった分類するクラス(血管以外と血管)のピクセル数の出現頻度が偏っていると、損失関数がCross Entropyではうまく学習できない事が多いそうです。(参考Webサイト)

ここまで損失関数としてBinary Cross Entropyを使ってきましたが、データの不均衡に強いDICE関数に変更します。
こちらこちらのWebサイトを参考にして、DICE関数を以下のように定義しました。

Colabセル3(追記部のみ抜粋)
def dice_loss(inputs, targets, smooth=1):
    
    #comment out if your model contains a sigmoid or equivalent activation layer
    inputs = F.sigmoid(inputs)       
    #flatten label and prediction tensors
    inputs = inputs.view(-1)
    targets = targets.view(-1)
    
    intersection = (inputs * targets).sum()                            
    dice_loss = 1 - (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)  
    return dice_loss

損失関数の設定部分を書き換えます。

Colabセル3(追記部のみ抜粋)
def main():
#    (前略)
#    loss_function = F.binary_cross_entropy_with_logits
    loss_function = dice_loss
#    (後略)
修正したColabセル3の全体
Colabセル3(修正後の全体)
#損失関数
def dice_loss(inputs, targets, smooth=1):
    
    #comment out if your model contains a sigmoid or equivalent activation layer
    inputs = F.sigmoid(inputs)       
    #flatten label and prediction tensors
    inputs = inputs.view(-1)
    targets = targets.view(-1)
    
    intersection = (inputs * targets).sum()                            
    dice_loss = 1 - (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)  
    return dice_loss


# data augmentation用のクラス
class RandomHorizontalFlip(object):
    def __call__(self, sample):
        image, label = sample["img"], sample["seg"]
        if random.random() > 0.5:
            image = TF.hflip(image)
            label = TF.hflip(label)        
        return {"img": image, "seg": label}

class RandomVerticalFlip(object):
    def __call__(self, sample):
        image, label = sample["img"], sample["seg"]
        if random.random() > 0.5:
            image = TF.vflip(image)
            label = TF.vflip(label)
        return {"img": image, "seg": label}

class ColorJitter(object):
    def __call__(self, sample):
        image, label = sample["img"], sample["seg"]
        # Color jitteは元画像にだけ掛ける。正解画像には掛けない。
        image = transforms.ColorJitter(brightness=0.4, contrast=0.5)(image)
        return {"img": image, "seg": label}

class RandomAffine(object):
    def __call__(self, sample):
        image, label = sample["img"], sample["seg"]
        angle = random.randint(-180, 180)
        x_translate = random.uniform(-0.05, 0.05)
        y_translate = random.uniform(-0.05, 0.05)
        scale = random.uniform(0.9, 1.1)

        image = TF.affine(image, angle, [x_translate, y_translate], scale, 0.0)
        label = TF.affine(label, angle, [x_translate, y_translate], scale, 0.0)

        #デバック用に表示する
#        display(image)
#        display(label)

        return {"img": image, "seg": label}

def main():
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    IMAGE_ROOT = drive_root_dir+"datasets/training/images"
    LABEL_ROOT = drive_root_dir+"datasets/training/1st_manual"

    print(IMAGE_ROOT)
    print(LABEL_ROOT)

    images = sorted(glob(os.path.join(IMAGE_ROOT, "*training.tif")))  # sortedの追加が必要
    labels = sorted(glob(os.path.join(LABEL_ROOT, "*manual1.gif")))

    data_transform_for_train = transforms.Compose([
        Rescale((512, 512)),
        RandomHorizontalFlip(),   #追加
        RandomVerticalFlip(),     #追加
        RandomAffine(),           #追加
        ColorJitter(),            #追加
        Normalize(),
        WeightMap(),
        ToTensor(),
    ])

    data_transform = transforms.Compose([
        Rescale((512, 512)),
        Normalize(),
        WeightMap(),
        ToTensor(),
    ])

    train_ds = DRIVEDataset(images[:-10], labels[:-10], transform=data_transform_for_train)
    valid_ds = DRIVEDataset(images[-10:], labels[-10:], transform=data_transform)

    train_loader = DataLoader(
        train_ds, batch_size=4, shuffle=True, num_workers=0, pin_memory=True,
    )

    val_loader = DataLoader(
        valid_ds, batch_size=2, shuffle=True, num_workers=0, pin_memory=True,
    )

    # create UNet and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = monai.networks.nets.UNet(
        dimensions=2,
        in_channels=3,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
#    loss_function = F.binary_cross_entropy_with_logits
    loss_function = dice_loss
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-5)

    # start a typical PyTorch training
    epochs_total = 10
    val_interval = 1
    best_loss = np.inf
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    writer = SummaryWriter()
    for epoch in range(epochs_total):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{epochs_total}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels, weights = (
                batch_data["img"].to(device),
                batch_data["seg"].to(device),
                batch_data["map"].to(device),
            )
            optimizer.zero_grad()
            outputs = model(inputs).squeeze()
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            epoch_len = len(train_ds) // train_loader.batch_size
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
            writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                epoch_val_loss = []
                val_images = None
                val_labels = None
                val_outputs = None
                for val_data in val_loader:
                    val_images, val_labels, val_weights = (
                        val_data["img"].to(device),
                        val_data["seg"].to(device),
                        val_data["map"].to(device),
                    )
                    val_outputs = model(val_images).squeeze()
                    loss = loss_function(val_outputs, val_labels)
                    epoch_val_loss.append(loss.item())
                epoch_val_loss = np.array(epoch_val_loss).mean()
                if epoch_val_loss < best_loss:
                    best_loss = epoch_val_loss
                    best_metric_epoch = epoch + 1
                    torch.save(
                        model.state_dict(), 
                        drive_root_dir+"best_metric_model_segmentation2d_dict.pth" # このファイルパスも修正しています。
                    )
                    print("saved new best metric model")
                print(
                    "current epoch: {} current val loss: {:.4f} best val loss: {:.4f} at epoch {}".format(
                        epoch + 1, epoch_val_loss, best_loss, best_metric_epoch
                    )
                )
                writer.add_scalar("val_loss", epoch_val_loss, epoch + 1)
                # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="image")
                plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="label")
                plot_2d_or_3d_image(
                    val_outputs, epoch + 1, writer, index=0, tag="output"
                )

    print(f"train completed, best_loss: {best_loss:.4f} at epoch: {best_metric_epoch}")
    writer.close()


if __name__ == "__main__":

    main()

損失関数の変更を行って学習を再度行い、得られた重みファイルを使ってPart3の通り推論すると以下のような出力となりました。細かい血管部分も血管と判定するようになりました。

image.png

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?