はじめに
クラス間の出現頻度にばらつきがあるデータセットでは、高頻度クラスが台頭し低頻度クラスの予測精度が相対的に下がる問題が生まれます。この問題に対して損失関数への工夫という切り口から改善を図ったCVPR2019の論文、「Class-Balanced Loss Based on Effective Number of Samples」について解説します。
1. 論文解説
概要
不均衡データの分類問題に対しては一般に、データのresampling・損失関数のreweightingという二つの手法がとられます。前者のresamplingは少ないデータを水増しするため過学習を起こしやすいという問題があり、近年ではreweightingの方に焦点が当てられることが多いです。
reweightingは、「出現数が全然違うクラスを均等に学習しても上手くいかないから、クラスごとにいい感じの重みをつけて結果的にバランスが取れるように学習しよう」という試みです。従来はこの「いい感じの重み」には各クラスのサンプル数(=データセットでの出現数)の逆数が使われていたのですが、データセットが大規模になると効果が出ないという問題がありました。
そこでこの論文では
- 「サンプル数」の定義を再考して新たな損失関数の重みを提案
- 提案手法を用いて不均衡データセットにおける分類精度を改善
しました。次章以降で手法の中身について詳しく見ていきます。
提案手法
Effective Number
最初にこの論文のキーワードである「Effective Number」の考え方について説明します。上でも述べたように、既存手法では各クラスの出現数の逆数が重みとして用いられていました。しかしサンプルの中には、特徴量空間で他と類似しているサンプルやAugmentationで複製したサンプルがあり、これらは実質的には「被り」であり一つの独立したデータとしてカウントするのはどうなのか?と言うのがこの論文の注目点です。そこで、この「被り」サンプルは無視しユニークなサンプルだけを数えて「サンプル数」とみなすことにし、これをEffective Numberと読んでいます。
ではこのEffective Numberはどう計算すればいいのでしょうか。結論から行くと、出現数$n$のクラスのEffective Numberは以下の$E_n$のように定義されます。
$$E_n = \frac{\left(1-\beta^n\right)}{1-\beta},\tag{1}$$$$
\beta = \frac{\left(N-1\right)}{N}\tag{2}$$
ここで$n$は「被り」も含めた文字通りの出現数、$N$はEffective Numberの上限を指します。直感的には$\beta$が$E_n$の増加速度をコントロールするハイパーパラメータで、$\beta=0$(N=1を意味:全サンプル被り)のとき$E_n=1$、$\beta\rightarrow1$(N→∞を意味:全サンプル独立)のとき$E_n\rightarrow n$となります。
Class-Balanced Loss
分類ベクトルを$\mathbf{p} \in \mathbb{R}^C$、正解ラベルを$y \in \lbrace 1,2, \dots, C\rbrace$として、ロス関数を$\mathcal{L}(\mathbf{p}, y)$とします。このとき提案手法でのClass-Balancedロスは、Effective Numberを用いて以下のように計算できます。
$$\mathrm{CB}(\mathbf{p}, y) = \frac{1}{E_{n_y}}\mathcal{L}(\mathbf{p}, y) = \frac{1-\beta}{1-\beta^{n_y}}\mathcal{L}(\mathbf{p}, y)$$
式を見てもわかるように、この手法はロス関数やモデル構造に非依存に適用が可能です。本論文では従来のロス関数であるSoftmax Cross-Entoropy Loss、Sigmoid Cross-Entoropy Loss、Focal LossにそれぞれこのClass-Blanced処理を加えて実験を行なっています。
2. 実験
論文の手法を検証するため、本記事では実際にClass-Balanced LossをPytorchで実装しCIFAR10におけるクラス分類の精度評価を行いました。
Class-Balanced Lossの実装
コードベース全体はこちらで公開しております。(ちなみに論文の公式コードはこちらになります。)
本記事では、論文の肝であるClass-Balanced Lossについての実装方法について説明します。
Effective Numberを使った重みは以下のget_loss_func
関数のように計算します。ここでnum_per_cls
はクラスごとの出現数を表し、cfg.betaは$\beta$を指します。
全体の損失関数CombinedLoss
は分類ロス(Softmax Cross-Entoropy LossまたはSigmoid Cross-Entoropy LossまたはFocal Loss)に正則化項を加えたものです。
import torch
import torch.nn as nn
import torch.nn.functional as F
def get_loss_func(cfg, num_per_cls):
weight = None
if cfg.beta:
weight = (1. - cfg.beta) / (1. - torch.pow(cfg.beta, torch.tensor(num_per_cls)))
weight = weight / torch.sum(weight) * len(num_per_cls)
return CombinedLoss(cfg, weight)
class CombinedLoss(nn.Module):
def __init__(self, cfg, weight=None):
super(CombinedLoss, self).__init__()
self.weight_decay = cfg.weight_decay
self.loss_type = cfg.loss_type
if cfg.loss_type == "softmax":
self.cls_loss = SoftmaxCrossEntorpyLoss(weight)
elif cfg.loss_type == "sigmoid":
self.cls_loss = SigmoidCrossEntropyLoss(weight)
else:
self.cls_loss = FocalLoss(cfg.gamma, weight)
self.cls_loss.cuda()
def forward(self, out, label, named_parameters):
loss_items = {}
loss_items["classification"] = self.cls_loss(out, label)
loss_items["regularization"] = self.weight_decay * sum([torch.norm(param) for (name, param) in named_parameters if self.loss_type == "softmax" or "linear.bias" not in name])
return loss_items
分類ロスはSoftmax Cross-Entoropy Loss、Sigmoid Cross-Entoropy Loss、Focal Lossの三種類を実装しました。
Softmax Cross-Entoropy Lossは、pytorchのCrossEntropyLoss
を用います。Class-Balancedにするときには、weightパラメータに先程計算したクラスごとの重みを渡します。
class SoftmaxCrossEntorpyLoss(nn.Module):
def __init__(self, weight=None):
super(SoftmaxCrossEntorpyLoss, self).__init__()
self.criterion = nn.CrossEntropyLoss(weight=weight, reduction="sum")
def forward(self, out, label):
loss_sum = self.criterion(out, label)
return loss_sum/out.shape[0]
Sigmoid Cross-Entoropy Lossは、pytorchのbinary_cross_entropy_with_logits
を用います。ここの部分は論文の記述と公式のコードの実装が違っていたので、公式コードの方に合わせて実装しました。Class-Balancedにするときは、binary_cross_entropy _with_logits
のweightパラメータにクラスごとの重みを渡します。
class SigmoidCrossEntropyLoss(nn.Module):
def __init__(self, weight=None):
super(SigmoidCrossEntropyLoss, self).__init__()
self.weight = weight
def forward(self, out, label):
weight = None
if self.weight is not None:
weight = self.weight[label].view(-1,1).repeat(1, out.shape[1]).cuda()
one_hot = F.one_hot(label, num_classes=out.shape[1]).to(torch.float)
loss = F.binary_cross_entropy_with_logits(out, one_hot, weight=weight, reduction='none')
return loss.sum()/out.shape[0]
Focal Lossについても、論文と公式コードの実装が異なっていたのでコードの方に合わせました。binary_cross_entropy_with_logits
では、各ラベルの予測確率をpとすると、正解ラベルではp/それ以外だと1-pを小さくするように学習します。そのため、Focal Lossの重みは正解ラベルでは$(1-p)^\gamma$/それ以外では$p^\gamma$となります。下のコードのfocal_weight
では安定した実装のために計算式が複雑化していますが(公式コード参照)、ちゃんと計算すると$(1-p)^\gamma$または$p^\gamma$が得られるはずです。
Class-Balancedにするときは、Sigmoid同様binary_cross_entropy _with_logits
のweightパラメータにクラスごとの重みを渡します。
また学習の過程でtorch.expの出力がinfになってしまう問題が発生したため、log-sum-exp-trickの計算方法を適用しました。
class FocalLoss(nn.Module):
def __init__(self, gamma, weight=None):
super(FocalLoss, self).__init__()
self.gamma = gamma
self.weight = weight
def forward(self, out, label):
weight = None
if self.weight is not None:
weight = self.weight[label].view(-1,1).repeat(1, out.shape[1]).cuda()
one_hot = F.one_hot(label, num_classes=out.shape[1]).to(torch.float)
logpt = F.binary_cross_entropy_with_logits(out, one_hot, weight=weight, reduction='none')
max_val = torch.clamp(-out, min=0)
focal_weight = torch.exp(-self.gamma * one_hot * out - self.gamma * (torch.log(torch.exp(1. - max_val) + torch.exp(-out - max_val)) + max_val))
loss = focal_weight * logpt
return loss.sum()/out.shape[0]
実験条件
実験条件は論文に倣い、以下のように設定しました。
- データセット: CIFAR10
- モデル: ResNet32
- エポック数: 200
- バッチサイズ:128
- 学習率: 0.1
- 1GPU
学習率は最初の5epochで線形的に0.1まで増加させ、その後160エポック、180エポックで0.01倍に減衰させます。Effective Numberにおける$\beta$は[0.9, 0.99, 0.999, 0.9999]の4種類を試し、一番精度がよかったものを採用します。またFocal Lossの$\gamma$は[0.5, 1.0, 2.0]の3種類について試しました。
実験結果
各分類ロスにおいて、class-balanced処理を入れるとき/入れないときの画像分類精度を以下の表に示します。
分類ロス | Softmax | Sigmoid | Focal ($\gamma$=0.5) | Focal ($\gamma$=1.0) | Focal ($\gamma$=2.0) |
---|---|---|---|---|---|
normal | 57.2 | 59.1 | 57.8 | 59.4 | 57.2 |
class-balanced | 61.4 | 62.1 | 61.7 | 60.1 | 58.2 |
$\beta$ | 0.99 | 0.9999 | 0.999 | 0.99 | 0.99 |
class-balancedロスを用いることで、精度が上がっていることが検証できました。
参考文献
Y. Cui, M. Jia, T. -Y. Lin, Y. Song and S. Belongie, "Class-Balanced Loss Based on Effective Number of Samples," 2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), 2019, pp. 9260-9269.