この記事でやること
PyTorchでMaskR-CNNを実装しているときに、「focal loss」を導入したいなぁと思った。
そこで、頑張って導入しました、という記事です。focal loss自体の解説記事ではないです。
モチベーション
自分が使っているデータは偏りが大きいので、RetinaNetのようにfocal lossを導入すれば識別率が良くなるのではないか、と考えた。そこで、「MaskR-CNNにfocal lossを導入したい」と思うに至った。
どこに分類のloss関数があるのか
探したらここにありました。
PyTrochのMaskR-CNN及びFasterR-CNNのクラス分類では__cross_entropy__が使われているみたいです。
実装
このフォーラムを参考に実装しました。
コードそのものはこれです。
class FocalLoss(nn.Module):
def __init__(self, weight=None,
gamma=2.5, reduction='mean'):
nn.Module.__init__(self)
self.gamma = gamma
self.reduction = reduction
self.weight=weight
def forward(self, input_tensor, target_tensor):
log_prob = F.log_softmax(input_tensor, dim=-1)
prob = torch.exp(log_prob)
return F.nll_loss(
((1 - prob) ** self.gamma) * log_prob,
target_tensor,
weight=self.weight,
reduction = self.reduction
)
これをfastrcnn_lossに導入します。ここで、ファイル名をmy_fastrcnn_loss_with_focal_loss.pyとして保存しておきます。
import torch.nn.functional as F
from torch import nn
class FocalLoss(nn.Module):
def __init__(self, weight=None,
gamma=2.5, reduction='mean'):
nn.Module.__init__(self)
self.weight=weight
self.gamma = gamma
self.reduction = reduction
def forward(self, input_tensor, target_tensor):
log_prob = F.log_softmax(input_tensor, dim=-1)
prob = torch.exp(log_prob)
return F.nll_loss(
((1 - prob) ** self.gamma) * log_prob,
target_tensor,
weight=self.weight,
reduction = self.reduction
)
def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
# type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
"""
Computes the loss for Faster R-CNN.
Args:
class_logits (Tensor)
box_regression (Tensor)
labels (list[BoxList])
regression_targets (Tensor)
Returns:
classification_loss (Tensor)
box_loss (Tensor)
"""
labels = torch.cat(labels, dim=0)
regression_targets = torch.cat(regression_targets, dim=0)
#この部分をfocal_lossへ変更する
#classification_loss = F.cross_entropy(class_logits, labels)
focal=FocalLoss()
classification_loss = focal(class_logits, labels)
#変更はここまで
# get indices that correspond to the regression targets for
# the corresponding ground truth labels, to be used with
# advanced indexing
sampled_pos_inds_subset = torch.where(labels > 0)[0]
labels_pos = labels[sampled_pos_inds_subset]
N, num_classes = class_logits.shape
box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4)
box_loss = F.smooth_l1_loss(
box_regression[sampled_pos_inds_subset, labels_pos],
regression_targets[sampled_pos_inds_subset],
beta=1 / 9,
reduction='sum',
)
box_loss = box_loss / labels.numel()
return classification_loss, box_loss
この後は前回と同じようにMaskR-CNNを実装した後、特に
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
を実装した後にコードを入れていきます。(途中までのコードは公式のチュートリアルです)
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
num_classes = 2
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
# ここまででネットワークを作る
# ここから損失関数を変更していく
from my_fastrcnn_loss_with_focal_loss import fastrcnn_loss
torchvision.models.detection.roi_heads.fastrcnn_loss=fastrcnn_loss
とします。これで、Mask R-CNNにfocal lossを実装することができました。(できたはず)
余談
weightを設定するときは
class FocalLoss(nn.Module):
def __init__(self,
gamma=2.5, reduction='mean'):
nn.Module.__init__(self)
self.weight=self.weight = torch.tensor([1. ,100. ,2. ,2. ,1.]).cuda()
self.gamma = gamma
self.reduction = reduction
def forward(self, input_tensor, target_tensor):
...
とcuda coreに乗せて設定することでうまくいった。
終わりに
実際、focal lossの導入に関しては、Region Proposal Networkを使うモデルにはあんまり効果がないんじゃないか、という懸念もありますが、どうなるんでしょう。
ちょっと疑問が残ってしまいましたが、実装自体はできたと思います。よかった。