4
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

論文読み "Class-Balanced Loss Based on Effective Number of Samples"

Posted at

はじめに

クラス間の出現頻度にばらつきがあるデータセットでは、高頻度クラスが台頭し低頻度クラスの予測精度が相対的に下がる問題が生まれます。この問題に対して損失関数への工夫という切り口から改善を図ったCVPR2019の論文、「Class-Balanced Loss Based on Effective Number of Samples」について解説します。

1. 論文解説

概要

不均衡データの分類問題に対しては一般に、データのresampling・損失関数のreweightingという二つの手法がとられます。前者のresamplingは少ないデータを水増しするため過学習を起こしやすいという問題があり、近年ではreweightingの方に焦点が当てられることが多いです。
reweightingは、「出現数が全然違うクラスを均等に学習しても上手くいかないから、クラスごとにいい感じの重みをつけて結果的にバランスが取れるように学習しよう」という試みです。従来はこの「いい感じの重み」には各クラスのサンプル数(=データセットでの出現数)の逆数が使われていたのですが、データセットが大規模になると効果が出ないという問題がありました。
そこでこの論文では

  • 「サンプル数」の定義を再考して新たな損失関数の重みを提案
  • 提案手法を用いて不均衡データセットにおける分類精度を改善

しました。次章以降で手法の中身について詳しく見ていきます。

提案手法

Effective Number

最初にこの論文のキーワードである「Effective Number」の考え方について説明します。上でも述べたように、既存手法では各クラスの出現数の逆数が重みとして用いられていました。しかしサンプルの中には、特徴量空間で他と類似しているサンプルやAugmentationで複製したサンプルがあり、これらは実質的には「被り」であり一つの独立したデータとしてカウントするのはどうなのか?と言うのがこの論文の注目点です。そこで、この「被り」サンプルは無視しユニークなサンプルだけを数えて「サンプル数」とみなすことにし、これをEffective Numberと読んでいます。

ではこのEffective Numberはどう計算すればいいのでしょうか。結論から行くと、出現数$n$のクラスのEffective Numberは以下の$E_n$のように定義されます。

$$E_n = \frac{\left(1-\beta^n\right)}{1-\beta},\tag{1}$$$$
\beta = \frac{\left(N-1\right)}{N}\tag{2}$$

ここで$n$は「被り」も含めた文字通りの出現数、$N$はEffective Numberの上限を指します。直感的には$\beta$が$E_n$の増加速度をコントロールするハイパーパラメータで、$\beta=0$(N=1を意味:全サンプル被り)のとき$E_n=1$、$\beta\rightarrow1$(N→∞を意味:全サンプル独立)のとき$E_n\rightarrow n$となります。

Class-Balanced Loss

分類ベクトルを$\mathbf{p} \in \mathbb{R}^C$、正解ラベルを$y \in \lbrace 1,2, \dots, C\rbrace$として、ロス関数を$\mathcal{L}(\mathbf{p}, y)$とします。このとき提案手法でのClass-Balancedロスは、Effective Numberを用いて以下のように計算できます。

$$\mathrm{CB}(\mathbf{p}, y) = \frac{1}{E_{n_y}}\mathcal{L}(\mathbf{p}, y) = \frac{1-\beta}{1-\beta^{n_y}}\mathcal{L}(\mathbf{p}, y)$$

式を見てもわかるように、この手法はロス関数やモデル構造に非依存に適用が可能です。本論文では従来のロス関数であるSoftmax Cross-Entoropy Loss、Sigmoid Cross-Entoropy Loss、Focal LossにそれぞれこのClass-Blanced処理を加えて実験を行なっています。

2. 実験

論文の手法を検証するため、本記事では実際にClass-Balanced LossをPytorchで実装しCIFAR10におけるクラス分類の精度評価を行いました。

Class-Balanced Lossの実装

コードベース全体はこちらで公開しております。(ちなみに論文の公式コードはこちらになります。)
本記事では、論文の肝であるClass-Balanced Lossについての実装方法について説明します。

Effective Numberを使った重みは以下のget_loss_func関数のように計算します。ここでnum_per_clsはクラスごとの出現数を表し、cfg.betaは$\beta$を指します。
全体の損失関数CombinedLossは分類ロス(Softmax Cross-Entoropy LossまたはSigmoid Cross-Entoropy LossまたはFocal Loss)に正則化項を加えたものです。

import torch
import torch.nn as nn
import torch.nn.functional as F

def get_loss_func(cfg, num_per_cls):
    weight = None
    if cfg.beta:
        weight = (1. - cfg.beta) / (1. - torch.pow(cfg.beta, torch.tensor(num_per_cls)))
        weight = weight / torch.sum(weight) * len(num_per_cls)
    return CombinedLoss(cfg, weight)

class CombinedLoss(nn.Module):
    def __init__(self, cfg, weight=None):
        super(CombinedLoss, self).__init__()
        self.weight_decay = cfg.weight_decay
        self.loss_type = cfg.loss_type
        if cfg.loss_type == "softmax":
            self.cls_loss = SoftmaxCrossEntorpyLoss(weight)
        elif cfg.loss_type == "sigmoid":
            self.cls_loss = SigmoidCrossEntropyLoss(weight)
        else:
            self.cls_loss = FocalLoss(cfg.gamma, weight)
        self.cls_loss.cuda()
    
    def forward(self, out, label, named_parameters):
        loss_items = {}
        loss_items["classification"] = self.cls_loss(out, label)
        loss_items["regularization"] = self.weight_decay * sum([torch.norm(param) for (name, param) in named_parameters if self.loss_type == "softmax" or "linear.bias" not in name])
        return loss_items

分類ロスはSoftmax Cross-Entoropy Loss、Sigmoid Cross-Entoropy Loss、Focal Lossの三種類を実装しました。

Softmax Cross-Entoropy Lossは、pytorchのCrossEntropyLossを用います。Class-Balancedにするときには、weightパラメータに先程計算したクラスごとの重みを渡します。

class SoftmaxCrossEntorpyLoss(nn.Module):
    def __init__(self, weight=None):
        super(SoftmaxCrossEntorpyLoss, self).__init__()
        self.criterion = nn.CrossEntropyLoss(weight=weight, reduction="sum")
    
    def forward(self, out, label):
        loss_sum = self.criterion(out, label)
        return loss_sum/out.shape[0]

Sigmoid Cross-Entoropy Lossは、pytorchのbinary_cross_entropy_with_logitsを用います。ここの部分は論文の記述と公式のコードの実装が違っていたので、公式コードの方に合わせて実装しました。Class-Balancedにするときは、binary_cross_entropy _with_logitsのweightパラメータにクラスごとの重みを渡します。

class SigmoidCrossEntropyLoss(nn.Module):
    def __init__(self, weight=None):
        super(SigmoidCrossEntropyLoss, self).__init__()
        self.weight = weight
    
    def forward(self, out, label):
        weight = None
        if self.weight is not None:
            weight = self.weight[label].view(-1,1).repeat(1, out.shape[1]).cuda()
        one_hot = F.one_hot(label, num_classes=out.shape[1]).to(torch.float)
        loss = F.binary_cross_entropy_with_logits(out, one_hot, weight=weight, reduction='none')
        return loss.sum()/out.shape[0]

Focal Lossについても、論文と公式コードの実装が異なっていたのでコードの方に合わせました。binary_cross_entropy_with_logitsでは、各ラベルの予測確率をpとすると、正解ラベルではp/それ以外だと1-pを小さくするように学習します。そのため、Focal Lossの重みは正解ラベルでは$(1-p)^\gamma$/それ以外では$p^\gamma$となります。下のコードのfocal_weightでは安定した実装のために計算式が複雑化していますが(公式コード参照)、ちゃんと計算すると$(1-p)^\gamma$または$p^\gamma$が得られるはずです。
Class-Balancedにするときは、Sigmoid同様binary_cross_entropy _with_logitsのweightパラメータにクラスごとの重みを渡します。
また学習の過程でtorch.expの出力がinfになってしまう問題が発生したため、log-sum-exp-trickの計算方法を適用しました。

class FocalLoss(nn.Module):
    def __init__(self, gamma, weight=None):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.weight = weight
        
    def forward(self, out, label):
        weight = None
        if self.weight is not None:
            weight = self.weight[label].view(-1,1).repeat(1, out.shape[1]).cuda()
        one_hot = F.one_hot(label, num_classes=out.shape[1]).to(torch.float)
        logpt = F.binary_cross_entropy_with_logits(out, one_hot, weight=weight, reduction='none')
        max_val = torch.clamp(-out, min=0) 
        focal_weight = torch.exp(-self.gamma * one_hot * out - self.gamma * (torch.log(torch.exp(1. - max_val) + torch.exp(-out - max_val)) + max_val))
        loss = focal_weight * logpt
        return loss.sum()/out.shape[0]

実験条件

実験条件は論文に倣い、以下のように設定しました。

  • データセット: CIFAR10
  • モデル: ResNet32
  • エポック数: 200
  • バッチサイズ:128
  • 学習率: 0.1
  • 1GPU

学習率は最初の5epochで線形的に0.1まで増加させ、その後160エポック、180エポックで0.01倍に減衰させます。Effective Numberにおける$\beta$は[0.9, 0.99, 0.999, 0.9999]の4種類を試し、一番精度がよかったものを採用します。またFocal Lossの$\gamma$は[0.5, 1.0, 2.0]の3種類について試しました。

実験結果

各分類ロスにおいて、class-balanced処理を入れるとき/入れないときの画像分類精度を以下の表に示します。

分類ロス Softmax Sigmoid Focal ($\gamma$=0.5) Focal ($\gamma$=1.0) Focal ($\gamma$=2.0)
normal 57.2 59.1 57.8 59.4 57.2
class-balanced 61.4 62.1 61.7 60.1 58.2
$\beta$ 0.99 0.9999 0.999 0.99 0.99

class-balancedロスを用いることで、精度が上がっていることが検証できました。

参考文献

Y. Cui, M. Jia, T. -Y. Lin, Y. Song and S. Belongie, "Class-Balanced Loss Based on Effective Number of Samples," 2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), 2019, pp. 9260-9269.

4
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?