Help us understand the problem. What is going on with this article?

不均衡データを損失関数で攻略してみる

More than 1 year has passed since last update.

はじめに

 研究の用途においては、データセットによる精度への影響を避けるため、各クラスのデータ数が同じくらいの均衡なデータセットがよく利用されています。しかし、いざ実サービスで学習用のデータを集めようとしても、全てのクラスで同じ数のデータを集めるのは難しい場合があります。そのような不均衡なデータセットに対して、2019/01/16にarxivで発表があった Class-Balanced LossFocal Lossなど、損失関数の工夫だけでどこまで精度を上げられるのか検証してみました。

データ不均衡による精度の劣化

 CIFAR-10データセットを使って学習および検証を行いました。学習データはCIFAR-10の訓練用画像 50,000枚(各class 5,000枚)からautomobileとdeerとfrogのデータを250枚に減らして、意図的に不均衡にしたものを使用しています。検証にはCIFAR-10の検証用画像 10,000枚(各class 1,000枚)を使います。学習時のデータ数に基づいて、automobileとdeerとfrogの minorityグループ、それ以外の majorityグループ、10,000枚全てを使ったallグループの3つに分け、それぞれのグループ毎に検証して、正解率(Accuracy)を計測しました。学習データと検証データの内訳は下記の表のとおりになります。

class 学習データ 検証データ(minority) 検証データ(majority) 検証データ(all)
0 airplane 5,000 0 1,000 1,000
1 automobile 250 1,000 0 1,000
2 bird 5,000 0 1,000 1,000
3 cat 5,000 0 1,000 1,000
4 deer 250 1,000 0 1,000
5 dog 5,000 0 1,000 1,000
6 frog 250 1,000 0 1,000
7 horse 5,000 0 1,000 1,000
8 ship 5,000 0 1,000 1,000
9 truck 5,000 0 1,000 1,000

ネットワークの構造はResNet18で、全ての検証は下記の要件に従って学習を行います。

Epoch BatchSize Optimizer LearningRate
25 128 SGD Momentum
Momentum=0.9
WeightDecay=0.0005
0.1(1〜15 epoch)
0.01(16〜20 epoch)
0.001(21〜25 epoch)

ちなみに、不均衡ではない普通のCIFAR-10に対し、一般的な損失関数 SoftmaxCrossEntropy で学習した結果は次のようになりました。全てのグループでAccuracyが90%を超えています。

均衡データの学習結果

MajorityのAccuracy MinorityのAccuracy AllのAccuracy
90.23 % 94.4 % 91.4 %

次は、これを不均衡データにした場合の結果になります。

不均衡データの学習結果

MajorityのAccuracy MinorityのAccuracy AllのAccuracy
87.93 % 50.4 % 76.5 %

均衡なデータの場合に比べて、Minorityグループの精度が極端に悪く、それにひきずられる形でAllグループの精度も76.5%と悪くなりました。

データ数の逆数でバランシング

 損失関数の工夫として、クラスのデータ数の逆数を重みとして掛けて、Lossへの寄与率をクラスのデータ数に反比例するように調節する手法がよく使われています。クラス数が$C$で、各クラスのラベルが $y\in \{1, 2,...,C\}$であるデータセットについて考えてみます。モデルが出力するSoftmaxの値を$ \boldsymbol{p} =[p_1, p_2,..., p_C]^T$, ただし $p_i \in [0,1] \hspace{3pt} \forall \hspace{3pt} i$ 。学習時のミニバッチ数を$B$ 。ミニバッチ内の正解ラベルを$ \boldsymbol{m} =[m_1, m_2,..., m_B]^T$,ただし $m_i \in \{1, 2,...,C\} $ 。 その場合、通常のSoftmaxCrossEntropyLossは下記のようになります。

\textbf{CE}(\hspace{2pt}\boldsymbol{p}, \boldsymbol{m}) = -\sum_{i=1}^{B}log(p_{m_i}) \tag{1}

 クラス$y$の学習データの数を$n_y$とすると、クラス数の逆数でバランシングした損失関数InverseClassFrequencyLossは下記のように書けます。

\textbf{ICF}(\hspace{2pt}\boldsymbol{p}, \boldsymbol{m}) = -\sum_{i=1}^{B} \frac{1}{n_{m_i}} log(p_{m_i}) \tag{2}

 InverseClassFrequencyLossの検証結果は次のようになりました。(実際の検証では、ミニバッチ単位のLossの出力の大きさを通常のSoftmaxCrossEntropyLossと揃えるために、加重平均をとっています。加重平均をとっても、各クラスの全体に対する寄与率は変わりません。)

クラス数の逆数でバランシングの学習結果

MajorityのAccuracy MinorityのAccuracy AllのAccuracy
74.1 % 74.67 % 73.58 %

 通常のSoftmaxCrossEntropyLossに比べてMinorityグループの精度がだいぶ高くなりました。ただMajorityグループの精度は落ちてしまったため、Allグループの精度は悪くなっています。

Focal Loss

 OneStageのObject Detectionの学習において、背景(EasyNegative)がほとんどであり、クラスが不均衡状態になっているという仮説のもと、それを自動的にコスト調節してくれる損失関数として、Facebook AI Researchが提案した手法1です。ICCV2017で発表されStudent Best Paperに選ばれています。

Focal Loss

 上の図が示すように、通常の損失関数に$(1-p_t)^\gamma$が係数として乗じられています。したがって、$p_t$が小さく0に近い場合(大きく間違っている場合)は、この係数は1に近いため通常の損失関数の値に近くなります。逆に$p_t$が大きく1に近い場合(ほとんど間違っていない場合)は、係数は0に近いため損失関数の値はLossとしてほとんど計上されなくなります。このように損失関数の係数部分が自動的にEasyNegativeExampleをDownWeightし、結果としてHardNegativeに比重が置かれるように機能します。

このFocalLossを多クラス分類のSoftmaxCrossEntropyに適応すると下記のように書けます。

\textbf{FL}(\hspace{2pt}\boldsymbol{p}, \boldsymbol{m}) = -\sum_{i=1}^{B} (1-p_{m_i})^\gamma log(p_{m_i}) \tag{3}

$\gamma\in \{1.0, 1.5, 2.0\}$ の3パターンで検証してみました。

$\gamma$ MajorityのAccuracy MinorityのAccuracy AllのAccuracy
1.0 88.47 % 55.03 % 78.33 %
1.5 87.91 % 53.56 % 77.61 %
2.0 86.37 % 49.53 % 75.32 %

 InverseClassFrequencyLossに比べると、Minorityグループの精度は全体的に落ちていますが、Majorityグループの精度は全て高くなっています。$\gamma=1.0$のAllグループの精度は、これまでの不均衡データ検証の中では最も良い結果となりました。下図は $\gamma=1.0$のグラフになります。

FocalLossの学習結果

本質的なクラスのデータ数でバランシング

 先述した InverseClassFrequencyLossの検証結果からもわかる通り、クラスのデータ数の逆数を重みとして掛ける事で、不均衡データに対するバランシングの効果が期待できます。この考え方をさらに深掘りした Class-Balanced Loss Based on Effective Number of Samples2 という興味深い論文がarxivに上がっていたので紹介します。この論文では、InverseClassFrequencyLossで利用しているクラスのデータ数というのは、実は本質的なデータ数ではないという考え方を元に理論が構築されています。どういうことかと言うと、例えば、class0のデータが800枚、class1のデータが500枚からなるデータセットがあるとします。それぞれのclassをData Augmentationして、class0もclass1も1,000枚に増やした場合、表面上のデータ数はどちらも1,000枚と同じになりますが、本質的なデータ数という意味ではこの2つは違います。このように実データ数と本質的なデータ数には差があるという事です。
 もう一つ。ある画像から特徴ベクトルを抽出し、その特徴ベクトルを分類器にかける事で、画像のクラス分類するといったケースを考えてみます。この場合、分類に利用している特徴ベクトルは、ベクトルの次元数が固定です。したがって、そのベクトルの表現パターン数以上のデータ表現はできません。つまり、どんなに多くの画像を集めたとしても、この分類器にとってのデータ数は、ベクトルの表現量を超える事はないという事です。これもまた、実データ数と本質的なデータ数には差があるという例となります。
 この差は、タスクの種類やデータセットの特性、およびその周辺情報によっても変わってくるため、拡大や回転したものを同じデータ扱いにするのか、同じオブジェクトが写っていれば同じデータ扱いにするのか、これらは一概には決められません。しかし、ここで重要なのは、どのようなデータセットにおいても、少なからず実データ数と本質的なデータ数には差があるという事です。
 したがって、データセットのデータを、全てユニークなデータとしてカウントした値を不均衡データのバランシングに使うのではなく、本質的なデータ数に配慮した値を用いてデータバランシングを行うべき、というのが本論文の主張になります。

理論

 本質的なデータ数が具体的にどのような値になるのか説明します。本質的なデータ数の最大値を$N$とおきます。例えば、車の画像が無限にあったとして、全ての車の画像を本質的なデータ空間に写像した場合、$N$パターンとなります。実データ数を$n$、その本質的なデータ数を$E_n$とすると、下記のように書けます。

E_n = (1 - \beta^n)/(1 - \beta) \tag{4} \\
ただし、 \beta = (N - 1)/N

と、いきなり式が出てきましたが、数学的帰納法を使ってこれを証明します。
 まず、$E_1$は重なりようがないので、$E_1 = 1$となります。次に、$E_{n-1}$で成立する場合に$E_n$でも成立するかを考える必要がありますが、その前に下の図を見てください。

帰納法での証明の図

 これは、$n$個目のデータをサンプリングした際の、本質的なデータ空間の変化の様子を示した概念図になります。破線の領域は本質的なデータ空間を表しています。したがって、この領域のデータ数は$N$です。灰色の領域は$n-1$個の実データが写像された本質的なデータ空間の領域になります。したがって、この領域のデータ数は$E_{n-1}$です。次に、$n$個目のデータを本質的なデータ空間に写像しようとした場合、$E_{n-1}/N$の確率で灰色の領域に入ります。
したがって、

p = \frac{E_{n-1}}{N}

とすると、$E_n$の期待値は下記のようになります。

E_n = pE_{n-1} + (1-p)(E_{n-1} + 1) = 1 + \frac{N - 1}{N}E_{n-1}  \tag{5}

ここで、式$(4)$が$E_{n-1}$では成立すると仮定すると、

E_{n-1} =(1 - \beta^{n-1})/(1 - \beta) \tag{6}

と書け、式$(5)$の$E_{n-1}$に代入すると、

E_{n} =1 + \frac{N - 1}{N}E_{n-1} = 1 + \beta \frac{1 - \beta^{n-1}}{1 - \beta} = \frac{1 - \beta + \beta - \beta^{n}}{1 - \beta} = \frac{1 - \beta^{n}}{1 - \beta} \tag{7}

となります。$E_n$の場合も式$(4)$を満たすことがわかったので、証明はこれで完了となります。

式$(4)$にあるように、$\beta = (N - 1)/N$とおいており、$N\geqq1$なので、$\beta$の取り得る値の範囲は$0 \leqq \beta < 1$です。したがって、$n \rightarrow \infty$ の場合 $\beta^{n} \rightarrow 0$ となるので、

\lim_{n \to \infty}E_n = \lim_{n \to \infty}\frac{1 - \beta^{n}}{1 - \beta} = \frac{1}{1 - \beta} = N

となります。これは無限のデータを本質的なデータ空間に写像した場合、そのデータ数は$N$に収束する事を表しています。

次に、$\beta = 0$ $(N = 1)$の場合をみてみると、

E_{n} = (1 - 0^n) / (1 - 0) = 1

となり、$\beta = 0$では、全てのデータは本質的には全て同じものである事がわかります。

また、$\beta \rightarrow 1$ $(N \rightarrow \infty)$の場合は、

f(\beta) = 1 - \beta^n, \quad g(\beta) = 1 - \beta

とすると、

f'(\beta) = -n \beta^{n-1}, \quad g'(\beta) = -1

ロピタルの定理より、

\lim_{\beta \to 1} E_n = \lim_{\beta \to 1} \frac{f(\beta)}{g(\beta)} = \lim_{\beta \to 1} \frac{f'(\beta)}{g'(\beta)} = \lim_{\beta \to 1} \frac{(-n \beta^{n-1})}{(-1)} = n

となります。したがって $\beta \rightarrow 1$では、全てのデータは本質的にもユニークであり、重複がない状態である事がわかります。

実際のデータセットの場合

クラス$y$の学習データの実データ数を$n_y$とすると、クラス$y$の本質的なデータ数 $E_{n_y}$は、

E_{n_y} = (1 - \beta_y^{n_y})/(1 - \beta_y) \\
ただし、 \beta_y = (N_y - 1)/N_y

となります。しかし、クラス$y$についての周辺情報がないため、$N_y$(クラスyを本質的に表現できるデータ数の最大値)を正確に求めるのは極めて難しい作業です。そこで、各クラスの$N_y$は、データセットの画素数や学習するネットワーク構造など、すべてのクラスに共通するもののみに依存すると仮定し、$N_y = N$, $\beta_y = \beta$として、全クラスで同じ値を使います。すると、$E_{n_y}$は下記のようになります。

E_{n_y} = (1 - \beta^{n_y})/(1 - \beta) \\
ただし、 \beta = (N - 1)/N


下図は $\beta$の値によって、$n_y$と$E_{n_y}$の関係がどのように変化するのかを可視化したグラフになります。横軸が$n_y$で縦軸が$E_{n_y}$です。

nとEの関係の図

Class Balanced Loss

InverseClassFrequencyLossでは、クラスのデータ数の逆数を重みとしてバランシングしていましたが、このクラスのデータ数の部分を本質的なクラスのデータ数 $E_{n_y}$に置き換える事で、ClassBalancedLoss になります。

SoftmaxCrossEntropyLoss を ClassBalanced 化すると下記のようになります。

\textbf{CB}_{softmax}(\hspace{2pt}\boldsymbol{p}, \boldsymbol{m}) = -\sum_{i=1}^{B} \frac{1}{E_{n_{m_i}}} log(p_{m_i}) = -\sum_{i=1}^{B} \frac{1 - \beta}{1 - \beta^{n_{m_i}}} log(p_{m_i})

FocalLoss を ClassBalanced 化すると下記のようになります。

\textbf{CB}_{focal}(\hspace{2pt}\boldsymbol{p}, \boldsymbol{m}) = -\frac{1}{E_{n_{m_i}}} \sum_{i=1}^{B} (1-p_{m_i})^\gamma log(p_{m_i}) = -\frac{1 - \beta}{1 - \beta^{n_{m_i}}} \sum_{i=1}^{B} (1-p_{m_i})^\gamma log(p_{m_i})

$\beta$はハイパーパラメータで、値の範囲は$0 \leqq \beta < 1$になります。$\beta = 0$は重みを全く掛けていない状態に相当し、$\beta \rightarrow 1$はInverseClassFrequencyLossと同じ状態に相当します。

検証

ClassBalancedSoftmaxは、$\beta\in \{0.9, 0.99, 0.999, 0.9999\}$ の4パターンで検証してみました。

$\beta$ MajorityのAccuracy MinorityのAccuracy AllのAccuracy
0.9 88.66 % 49.23 % 76.83 %
0.99 89.45 % 58.67 % 80.22 %
0.999 86.76 % 66.37 % 80.64 %
0.9999 80.94 % 75.23 % 79.23 %

$\beta=0.999$の時に、AllのAccuracyは $80.64$% となり、今までのどの検証よりも良い結果となっています。下記がそのグラフです。

ClassBalancedSoftmax beta0.999

ClassBalancedFocalLossは、$\gamma=1.0$固定として、$\beta\in \{0.9, 0.99, 0.999, 0.9999\}$ の4パターンで検証してみました。

$\beta$ $\gamma$ MajorityのAccuracy MinorityのAccuracy AllのAccuracy
0.9 1.0 88.67 % 53.70 % 78.18 %
0.99 1.0 89.42 % 60.63 % 80.60 %
0.999 1.0 85.11 % 64.70 % 78.99 %
0.9999 1.0 72.40 % 62.67 % 69.48 %

$\beta=0.99$の時に、AllのAccuracyは $80.60$% となり、ClassBalancedSoftmaxの最高値とほぼ同等の精度となりました。下記がグラフになります。

ClassBalancedFocalLoss beta0.99

まとめ

 ClassBalancedな損失関数を使えば、SoftmaxCrossEntropyやFocalLoss、どちらの場合も精度が向上する事が確認できました。また、理論部分に関しても証明付きで綺麗にまとまっており、とても読みやすい論文となっています。ただ、本質的なデータ空間という、独特な抽象概念を土台に理論が組み立てられているため、$\beta$ の値をあらかじめ正確に決めることは、原理的に不可能です。したがって、実際のデータに対してこの損失関数を適用する場合は、いくらか学習を試しながら、ハイパーパラメータ$\beta$の値を探査し、最適な値を決める必要があります。


  1. Tsung-Yi Lin, Priya Goyal, Ross Girshick, Kaiming He, and Piotr Dollár. "Focal Loss for Dense Object Detection" In ICCV, 2017. 

  2. Yin Cui, Menglin Jia, Tsung-Yi Lin, Yang Song, and Serge Belongie. "Class-Balanced Loss Based on Effective Number of Samples" In CVPR, 2019. 

tancoro
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
No comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
ユーザーは見つかりませんでした