0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

【SSDモデル】Hard Negative Mining について

Last updated at Posted at 2023-02-09

物体検出の SSD モデルについて勉強中ですが、そこに出てくる Hard Negative Mining の実装アルゴリズムが難解で面白かったのでメモっておきます。
基本的に「物体検出とGAN、オートエンコーダー、画像処理入門 PyTorch/TensorFlow2による発展的・実装ディープラーニング」を読んでいますが、Hard Negative Mining についてはわかりにくかったので 、「つくりながら学ぶ!PyTorchによる発展ディープラーニング」のgitサイトを読んで理解できました。

Hard Negative Mining とは、大まかに言えば損失値の配列 loss_c から損失値の大きい順に上位の部分を抜き出すものです。以下に簡単化したプログラムを示します。本物のloss_cは(バッチサイズ、8732)の2階テンソルですが、ここでは1階テンソルで、バッチの次元は無視しています。しかしアルゴリズムの本質は同じです。sortを2度取っているところが難解の原因で、これが何を行っているかがこの記事のポイントです。

import numpy as np
import torch

x_np = np.array([9, 1, 8, 2, 3])
loss_c = torch.tensor(x_np)                    # loss_c : 損失の配列

_, loss_idx = loss_c.sort(0, descending=True)  # loss_c を降順に並べる
print(loss_idx)                                # loss_idx = tensor([0, 2, 4, 3, 1])

_, idx_rank = loss_idx.sort(0)                 # loss_idx を昇順に並べる
print(idx_rank)                                # idx_rank = tensor([0, 4, 1, 3, 2])

ここでは以下が成り立つのがポイントです。

loss_c を降順で並べた時に、loss_c[i] は idx_rank[i] 番目になる。 ここで i は任意のindex。

つまり idx_rank には loss_c の降順の順番が入っています。面白いですよね。
これは以下のような意味です。

idx_rank[0] = 0   ==> loss_c[0]=9 は 0 番目
idx_rank[1] = 4   ==> loss_c[1]=1 は 4 番目
idx_rank[2] = 1   ==> loss_c[2]=8 は 1 番目
idx_rank[3] = 3   ==> loss_c[3]=2 は 3 番目
idx_rank[4] = 2   ==> loss_c[4]=3 は 2 番目

証明は以下のようになります。

loss_c = [c0, c1, c2, ...., cn] のようなテンソル配列とします。

 _, idx_rank = loss_idx.sort(0)                 # loss_idx を昇順に並べる

上の idx_rank の定義から以下が言える。
x = idx_rank[i] とすると、x は loss_idx の i 番目に小さい要素の index です。
そもそも loss_idx の要素は index なので 0,1,2,3,..,n の値となり、
i 番目に小さい要素は i となります。つまり loss_idx[x] = i です。

idx_rank[i]=x
 0             i
|-----------------------------------------------------------
|             |x|
|-----------------------------------------------------------

loss_idx[x]=i
 0                    x
|-----------------------------------------------------------
|                    |i|
|-----------------------------------------------------------


_, loss_idx = loss_c.sort(0, descending=True)  # loss_c を降順に並べる

上の lass_idx の定義から以下が言える。
loss_idx[x] = i の意味するところは、loss_c を降順に並べたときに、
loss_c[i] が x 番目、つまり idx_rank[i] 番目、であるということである。

loss_c[i]=ci
 0            i
|-----------------------------------------------------------
|            |ci|
|-----------------------------------------------------------
              ↑ idx_rank[i] 番目に大きい要素

idx_rank で Hard Negative Mining の mask をつくる

num_neg をある閾値とする。
loss_c[i] は idx_rank[i] 番目であるので、idx_rank で 以下のように mask を作ることで、
num_neg 番目以下の loss_c の要素を抽出できる。

neg_mask = idx_rank < (num_neg).expand_as(idx_rank)

実際に最初のプログラムの続きとして以下を実行すると望むものがえられます。

num_neg = torch.tensor(3)
neg_mask = idx_rank < (num_neg).expand_as(idx_rank)
print(neg_mask)      # tensor([ True, False,  True, False,  True])

つまり neg_mask で loss_c の 0 番目と 2 番, 4 番の要素を抽出することができます。
これは降順のトップ3ですので、期待にあったものです。

今回は以上です。

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?