search
LoginSignup
18
Help us understand the problem. What are the problem?

More than 3 years have passed since last update.

posted at

updated at

Organization

[PyTorch]不均衡データのを扱うための損失関数

不均衡データとは

正例と負例の比率が、1:99のように偏っているデータを不均衡データと呼びます。
全てのデータを負例と予測すれば精度は99%となるわけですが、PrecisionとRecallを考慮すると良くないモデルであることがわかります。

Positive Negative
True 0 10
False 0 990
Precision = \frac{TP}{TP + FP} \\
Recall    = \frac{TP}{TP + FN}

上式より、Recall, Precisionともにゼロとなります。

対策

対策方針としては大きくふたつあります。
- トレーニングデータからのサンプリング方法調整 (アンダーサンプリング・オーバーサンプリング)
- 損失関数の中でクラスの重みを調整

今回は、損失関数の中でクラスの重みを調整する方法を紹介します。

# [0,1]の二値分類タスクを解く想定
# 正例:1の発生頻度が少ないので、重みを付けたい

weights = torch.tensor([1.0, 100.0]).cuda()
cross_entropy_loss = nn.CrossEntropyLoss(weight=weights)

これだけです。簡単ですね。
PyTorchの損失関数に関する公式ドキュメントを見るとほとんどのメソッドにweight引数が設定されています。

不均衡データを扱っていてLossが減っているのにRecallが上がらない!という場合はぜひクラスの重みづけを検討してみてください。

以上です。

Register as a new user and use Qiita more conveniently

  1. You can follow users and tags
  2. you can stock useful information
  3. You can make editorial suggestions for articles
What you can do with signing up
18
Help us understand the problem. What are the problem?