2
3

More than 1 year has passed since last update.

PyTorchでIoUを求める (torchmetricsのiouモジュールは廃止済)

Posted at

概要

Semantic Segmentationを実施する際、IoU (Intersection over Union)を計算すると思います。

2021年末迄はtorchmetrics.functionalからで、iouモジュールをインポートすれば計算できていたのですが、バージョンアップによりjaccard_indexに変更になっています。1

このことに気付かず最初少し苦戦したので、記事として記録しておきます。

公式ドキュメント: https://torchmetrics.readthedocs.io/en/stable/classification/jaccard_index.html

Jaccard Indexとは

IoUと定義は全く同じです。ある集合とA, 別の集合Bのjaccard indexを$J(A,B)$としたとき、

J(A,B)=\frac{|A \cap B|}{|A \cup B|}

と計算されます。

下記サイトが実装例も記載されており、分かりやすいです。

また、英語ではありますが、Wikiも掲載しておきます。

実装例

UNetの構築

UNet (バックボーン: Resnet18)を用いて、Semantic Segmentationを行います。

Unetは、segmentation_models_pytorchですでにモジュールとして準備されていますので、今回はこちらを使用します。

また、NNの構築にはPyTorchLightningを使用します。

3チャネルの画像を入力するとし、計12クラスのセグメンテーションを実施します。

なお、ここでは必要なライブラリをすでに読み込んでいて、dataloaderも作成後のコードから記載しています。

# segmentation_models_pytorchモジュールの読み込み

import segmentation_models_pytorch as smp

# UNetの構築

class Net_resnet18(pl.LightningModule):
    def __init__(self, in_channels=3, n_classes=12):
        super().__init__()
        
        self.unet = smp.Unet("resnet18", in_channels=in_channels, classes=n_classes, encoder_weights="imagenet")
        
    def forward(self, x):
        return self.unet(x)
    
    def training_step(self, batch, batch_idx):
        x, t = batch
        y = self(x)
        loss = F.cross_entropy(y, t)
        self.log("train_loss", loss, on_step=True, on_epoch=True)
        self.log("train_acc", accuracy(y.softmax(dim=1), t), on_step=True, on_epoch=True)
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, t = batch
        y = self(x)
        loss = F.cross_entropy(y, t)
        self.log("val_loss", loss, on_step=True, on_epoch=True)
        self.log("val_acc", accuracy(y.softmax(dim=1), t), on_step=True, on_epoch=True)
        return loss
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())
        

学習実行

上記で定義したUNetをインスタンス化し、学習実行します。

PyTorchLightningの実装に関しては、こちらの記事が分かりやすいです。


pl.seed_everything(0)
net_resnet18 = Net_resnet18(n_classes=12)
trainer = pl.Trainer(max_epochs=30, gpus=1, deterministic=False)
trainer.fit(net_resnet18, train_loader, val_loader)

IoUの計算

torchmetrics.functionalからjaccard_indexをインポートし、計算します。

なお、functionalからのインポートであれば、特にjaccard_indexをインスタンス化せずともクラスメソッドとして使用できるので、便利です。


from torchmetrics.functional import jaccard_index

# validation_loaderのデータから、目標値と予測値を計算
y, t = predict(net_resnet18, val_loader, device="cpu")

print("pixel accuracy: ", accuracy(y, t))
print("IoU: ", jaccard_index(y, t, num_classes=12))

  1. https://github.com/Lightning-AI/metrics/pull/662/commits/2e96b03de0e0b3a8ddb5010df1df9f5e3197238b

2
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
3