Edited at

[PyTorch]不均衡データのを扱うための損失関数


不均衡データとは

正例と負例の比率が、1:99のように偏っているデータを不均衡データと呼びます。

全てのデータを負例と予測すれば精度は99%となるわけですが、PrecisionとRecallを考慮すると良くないモデルであることがわかります。

Positive
Negative

True
0
10

False
0
990

Precision = \frac{TP}{TP + FP} \\

Recall = \frac{TP}{TP + FN}

上式より、Recall, Precisionともにゼロとなります。


対策

対策方針としては大きくふたつあります。

- トレーニングデータからのサンプリング方法調整 (アンダーサンプリング・オーバーサンプリング)

- 損失関数の中でクラスの重みを調整

今回は、損失関数の中でクラスの重みを調整する方法を紹介します。

# [0,1]の二値分類タスクを解く想定

# 正例:1の発生頻度が少ないので、重みを付けたい

weights = torch.tensor([1.0, 100.0]).cuda()
cross_entropy_loss = nn.CrossEntropyLoss(weight=weights)

これだけです。簡単ですね。

PyTorchの損失関数に関する公式ドキュメントを見るとほとんどのメソッドにweight引数が設定されています。

不均衡データを扱っていてLossが減っているのにRecallが上がらない!という場合はぜひクラスの重みづけを検討してみてください。

以上です。