LoginSignup
41
21

More than 5 years have passed since last update.

[PyTorch]不均衡データのを扱うための損失関数

Last updated at Posted at 2018-12-16

不均衡データとは

正例と負例の比率が、1:99のように偏っているデータを不均衡データと呼びます。
全てのデータを負例と予測すれば精度は99%となるわけですが、PrecisionとRecallを考慮すると良くないモデルであることがわかります。

Positive Negative
True 0 10
False 0 990
Precision = \frac{TP}{TP + FP} \\
Recall    = \frac{TP}{TP + FN}

上式より、Recall, Precisionともにゼロとなります。

対策

対策方針としては大きくふたつあります。
- トレーニングデータからのサンプリング方法調整 (アンダーサンプリング・オーバーサンプリング)
- 損失関数の中でクラスの重みを調整

今回は、損失関数の中でクラスの重みを調整する方法を紹介します。

# [0,1]の二値分類タスクを解く想定
# 正例:1の発生頻度が少ないので、重みを付けたい

weights = torch.tensor([1.0, 100.0]).cuda()
cross_entropy_loss = nn.CrossEntropyLoss(weight=weights)

これだけです。簡単ですね。
PyTorchの損失関数に関する公式ドキュメントを見るとほとんどのメソッドにweight引数が設定されています。

不均衡データを扱っていてLossが減っているのにRecallが上がらない!という場合はぜひクラスの重みづけを検討してみてください。

以上です。

41
21
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
41
21