概要
Semantic Segmentationを実施する際、IoU (Intersection over Union)を計算すると思います。
2021年末迄はtorchmetrics.functional
からで、iou
モジュールをインポートすれば計算できていたのですが、バージョンアップによりjaccard_index
に変更になっています。1
このことに気付かず最初少し苦戦したので、記事として記録しておきます。
公式ドキュメント: https://torchmetrics.readthedocs.io/en/stable/classification/jaccard_index.html
Jaccard Indexとは
IoUと定義は全く同じです。ある集合とA, 別の集合Bのjaccard indexを$J(A,B)$としたとき、
J(A,B)=\frac{|A \cap B|}{|A \cup B|}
と計算されます。
下記サイトが実装例も記載されており、分かりやすいです。
また、英語ではありますが、Wikiも掲載しておきます。
実装例
UNetの構築
UNet (バックボーン: Resnet18)を用いて、Semantic Segmentationを行います。
Unetは、segmentation_models_pytorch
ですでにモジュールとして準備されていますので、今回はこちらを使用します。
また、NNの構築にはPyTorchLightningを使用します。
3チャネルの画像を入力するとし、計12クラスのセグメンテーションを実施します。
なお、ここでは必要なライブラリをすでに読み込んでいて、dataloaderも作成後のコードから記載しています。
# segmentation_models_pytorchモジュールの読み込み
import segmentation_models_pytorch as smp
# UNetの構築
class Net_resnet18(pl.LightningModule):
def __init__(self, in_channels=3, n_classes=12):
super().__init__()
self.unet = smp.Unet("resnet18", in_channels=in_channels, classes=n_classes, encoder_weights="imagenet")
def forward(self, x):
return self.unet(x)
def training_step(self, batch, batch_idx):
x, t = batch
y = self(x)
loss = F.cross_entropy(y, t)
self.log("train_loss", loss, on_step=True, on_epoch=True)
self.log("train_acc", accuracy(y.softmax(dim=1), t), on_step=True, on_epoch=True)
return loss
def validation_step(self, batch, batch_idx):
x, t = batch
y = self(x)
loss = F.cross_entropy(y, t)
self.log("val_loss", loss, on_step=True, on_epoch=True)
self.log("val_acc", accuracy(y.softmax(dim=1), t), on_step=True, on_epoch=True)
return loss
def configure_optimizers(self):
return torch.optim.Adam(self.parameters())
学習実行
上記で定義したUNetをインスタンス化し、学習実行します。
PyTorchLightningの実装に関しては、こちらの記事が分かりやすいです。
pl.seed_everything(0)
net_resnet18 = Net_resnet18(n_classes=12)
trainer = pl.Trainer(max_epochs=30, gpus=1, deterministic=False)
trainer.fit(net_resnet18, train_loader, val_loader)
IoUの計算
torchmetrics.functional
からjaccard_index
をインポートし、計算します。
なお、functionalからのインポートであれば、特にjaccard_indexをインスタンス化せずともクラスメソッドとして使用できるので、便利です。
from torchmetrics.functional import jaccard_index
# validation_loaderのデータから、目標値と予測値を計算
y, t = predict(net_resnet18, val_loader, device="cpu")
print("pixel accuracy: ", accuracy(y, t))
print("IoU: ", jaccard_index(y, t, num_classes=12))