2
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

PyTorchのMaskR-CNNにFocal_lossを導入する

Posted at

この記事でやること

PyTorchでMaskR-CNNを実装しているときに、「focal loss」を導入したいなぁと思った。
そこで、頑張って導入しました、という記事です。focal loss自体の解説記事ではないです。

モチベーション

自分が使っているデータは偏りが大きいので、RetinaNetのようにfocal lossを導入すれば識別率が良くなるのではないか、と考えた。そこで、「MaskR-CNNにfocal lossを導入したい」と思うに至った。

どこに分類のloss関数があるのか

探したらここにありました。

image.png

PyTrochのMaskR-CNN及びFasterR-CNNのクラス分類では__cross_entropy__が使われているみたいです。

実装

このフォーラムを参考に実装しました。
コードそのものはこれです。

focal_loss
class FocalLoss(nn.Module):
    
    def __init__(self, weight=None,
             gamma=2.5, reduction='mean'):
        nn.Module.__init__(self)
        self.gamma = gamma
        self.reduction = reduction
        self.weight=weight
        
    def forward(self, input_tensor, target_tensor):
        log_prob = F.log_softmax(input_tensor, dim=-1)
        prob = torch.exp(log_prob)
        return F.nll_loss(
            ((1 - prob) ** self.gamma) * log_prob, 
            target_tensor, 
            weight=self.weight,
            reduction = self.reduction
        )

これをfastrcnn_lossに導入します。ここで、ファイル名をmy_fastrcnn_loss_with_focal_loss.pyとして保存しておきます。

my_fastrcnn_loss_with_focal_loss.py
import torch.nn.functional as F
from torch import nn

class FocalLoss(nn.Module):
    
    def __init__(self, weight=None,
                 gamma=2.5, reduction='mean'):
        nn.Module.__init__(self)
        self.weight=weight
        self.gamma = gamma
        self.reduction = reduction
        
    def forward(self, input_tensor, target_tensor):
        log_prob = F.log_softmax(input_tensor, dim=-1)
        prob = torch.exp(log_prob)
        return F.nll_loss(
            ((1 - prob) ** self.gamma) * log_prob, 
            target_tensor, 
            weight=self.weight,
            reduction = self.reduction
        )

def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
    # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
    """
    Computes the loss for Faster R-CNN.
    Args:
        class_logits (Tensor)
        box_regression (Tensor)
        labels (list[BoxList])
        regression_targets (Tensor)
    Returns:
        classification_loss (Tensor)
        box_loss (Tensor)
    """

    labels = torch.cat(labels, dim=0)
    regression_targets = torch.cat(regression_targets, dim=0)

    #この部分をfocal_lossへ変更する
    #classification_loss = F.cross_entropy(class_logits, labels)
    focal=FocalLoss()
    classification_loss = focal(class_logits, labels)
    #変更はここまで

    # get indices that correspond to the regression targets for
    # the corresponding ground truth labels, to be used with
    # advanced indexing
    sampled_pos_inds_subset = torch.where(labels > 0)[0]
    labels_pos = labels[sampled_pos_inds_subset]
    N, num_classes = class_logits.shape
    box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4)

    box_loss = F.smooth_l1_loss(
        box_regression[sampled_pos_inds_subset, labels_pos],
        regression_targets[sampled_pos_inds_subset],
        beta=1 / 9,
        reduction='sum',
    )
    box_loss = box_loss / labels.numel()

    return classification_loss, box_loss

この後は前回と同じようにMaskR-CNNを実装した後、特に

model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

を実装した後にコードを入れていきます。(途中までのコードは公式のチュートリアルです)

mask_rcnn_model.py
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)

num_classes = 2
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

# ここまででネットワークを作る
# ここから損失関数を変更していく

from my_fastrcnn_loss_with_focal_loss import fastrcnn_loss
torchvision.models.detection.roi_heads.fastrcnn_loss=fastrcnn_loss

とします。これで、Mask R-CNNにfocal lossを実装することができました。(できたはず)

余談

weightを設定するときは

class FocalLoss(nn.Module):
    
    def __init__(self, 
                 gamma=2.5, reduction='mean'):
        nn.Module.__init__(self)
        self.weight=self.weight = torch.tensor([1. ,100. ,2. ,2. ,1.]).cuda()
        self.gamma = gamma
        self.reduction = reduction
        
    def forward(self, input_tensor, target_tensor):
    ...

とcuda coreに乗せて設定することでうまくいった。

終わりに

実際、focal lossの導入に関しては、Region Proposal Networkを使うモデルにはあんまり効果がないんじゃないか、という懸念もありますが、どうなるんでしょう。
ちょっと疑問が残ってしまいましたが、実装自体はできたと思います。よかった。

2
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?