0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

物体検出SSD-4 : 実装1 実装全体の流れとL2Normクラスを実装

Posted at

参考文献

  • 新納浩幸『PyTorchによる物体検出』オーム社, 2020年
  • Wei Liu, et al. "SSD: Single Shot MultiBox Detector", ECCV, pp. 21-37, 2016.

前回の内容

  • IoU
  • SSDの損失関数
  • Non Maximum Suppression

前回分を見ていない方はこちらから。

SSDの原論文

原論文はこちらから

今回の内容

  • 実装全体の流れ
  • L2Normalaization
  • 補足 nn.constant_

コードについて

  • 補足以外は基本的にpyファイルでの実装です
  • 別の形式での実装をする際は適宜説明します
  • 冗長なくらいコメントアウトがあります
  • 関数やメゾットの説明を補足で示す時がありますが、ipynbでの実行結果となっています。

実装全体の流れ(この順で説明します)

今回からコードの説明に移ります。コードの説明を全て見終えた時になんとなくでもいいので処理内容が浮かべばいいなと思い書きました。
L2Normクラス

  • L2Normalaizationを行うクラス

PriorBoxクラス

  • アスペクト比を用いてDBoxを生成するクラス

SSDクラス

  • SSDのモデルアーキテクチャを作成

SSDLossクラス

  • SSDの損失関数を実装

データの準備や学習ループ

  • データの準備や加工、学習ループを実装

Detectクラスとvisuable_detections関数

  • Detectクラス: テスト時にNon Maximum Suppression
    を行い、検出結果を出力
  • visuable_detections関数: 検出結果を受け取り、実際に可視化する関数

L2Normクラスの実装

L2Normalaizationを行うクラスです。L2Normalaizationがよく分からない人は


こちらを参考して下さい

以下の順番で説明します。

  • initメゾット
  • reset_parametersメゾット
  • forwardメゾット
  • 全体のコード

initメゾット

initメゾットでは必要な属性などの初期化を行います。詳しい内容等はコメントアウトやdocstringsを参照してください。正規化後に掛ける学習可能なパラメータはnn.Parameterを用いて行います。自分でモデルを実装する時に、これで学習可能なパラメータを定義することは最近増えました。(transformer系のモデルなど)


class L2Norm(nn.Module):
    """L2正規化レイヤー
    Attributes:
        n_channels (int): 入力のチャンネル数
        gamma (float): 重みの初期値
        eps (float): ゼロ除算を防ぐための小さな値(1e-10)
        weight (torch.nn.Parameter): 学習可能なスケーリング重み
    """

    def __init__(self, n_channels=512, scale=20):
        super().__init__()
        self.n_channels = n_channels
        # 正規化後に掛ける学習可能なパラメータ,channel分だけある
        # デフォルトでは初期値scale(20)
        self.gamma = scale 
        # 0で割ることを防ぐためのε
        self.eps = 1e-10 
        # 学習可能なスケーリング重み
        self.weight = nn.Parameter(torch.Tensor(self.n_channels))
        # weightを初期化
        self.reset_parameters()

reset_parametersメゾット

学習可能なパラメータをここで初期化します。nn.init.constant_(self.weight, self.gamma)はself.weight(initメゾットで定義したスケーリング用の学習可能なパラメータ)をself.gamma(initメゾットでself.gamma = scaleで定義されている)で初期化します。

def reset_parameters(self):
        """weightを初期化するメゾット
        """
        nn.init.constant_(self.weight, self.gamma) 

補足にnn.init.constant_の使用例を示しておきます。

forwardメゾット

最後にforwardメゾットの実装です。ここで正規化を実施します。詳しい内容はコメントアウトでたくさん書きましたので適宜参考にしてください。最後にself.weight.view(1, self.n_channels, 1, 1)としているのはスケーリング用のパラメータが1次元のテンソルなので、Xの形状をそろえてチャネルごとにスケーリング用のパラメータを適用するためです。(そもそも、.viewしないとブロードキャスト出来なくてエラーになります。)

 def forward(self, X):
        """
        Args:
            X (torch.Tensor): 入力テンソルの形状 [b, c, h, w]

        Returns:
            torch.Tensor: 正規化およびスケーリングされた出力テンソル
        """
        # channel方向にL2Normを計算 norm : [b, 1, h, w]
        # x.pow(n) : xのn乗
        # x.sum(dim=1, keepdim=True) : 指定した次元(dim=1 :channel方向)に沿って要素の和を計算
        # keepdim=True : 次元を保持する
        # sqrt()でL2Normが算出される
        # epsでゼロ除算を防ぐ
        norm = X.pow(2).sum(dim=1, keepdim=True).sqrt() + self.eps 
        # 入力をnormで割る(正規化を行う)
        X = torch.div(X, norm) 
        # スケーリングの重みを掛ける
        out = self.weight.view(1, self.n_channels, 1, 1) * X 
        return out

コード全体

最後に実装コード全体を載せておきます

class L2Norm(nn.Module):
    """L2正規化レイヤー
    Attributes:
        n_channels (int): 入力のチャンネル数
        gamma (float): 重みの初期値
        eps (float): ゼロ除算を防ぐための小さな値(1e-10)
        weight (torch.nn.Parameter): 学習可能なスケーリング重み
    """

    def __init__(self, n_channels=512, scale=20):
        super().__init__()
        self.n_channels = n_channels
        # 正規化後に掛ける学習可能なパラメータ,channel分だけある
        # デフォルトでは初期値scale(20)
        self.gamma = scale 
        # 0で割ることを防ぐためのε
        self.eps = 1e-10 
        # 学習可能なスケーリング重み
        self.weight = nn.Parameter(torch.Tensor(self.n_channels))
        # weightを初期化
        self.reset_parameters()

    def reset_parameters(self):
        """weightを初期化するメゾット
        """
        nn.init.constant_(self.weight, self.gamma) 

    def forward(self, X):
        """
        Args:
            X (torch.Tensor): 入力テンソルの形状 [b, c, h, w]

        Returns:
            torch.Tensor: 正規化およびスケーリングされた出力テンソル
        """
        # channel方向にL2Normを計算 norm : [b, 1, h, w]
        # x.pow(n) : xのn乗
        # x.sum(dim=1, keepdim=True) : 指定した次元(dim=1 :channel方向)に沿って要素の和を計算
        # keepdim=True : 次元を保持する
        # sqrt()でL2Normが算出される
        # epsでゼロ除算を防ぐ
        norm = X.pow(2).sum(dim=1, keepdim=True).sqrt() + self.eps 
        # 入力をnormで割る(正規化を行う)
        X = torch.div(X, norm) 
        # スケーリングの重みを掛ける
        out = self.weight.view(1, self.n_channels, 1, 1) * X 
        return out

補足 nn.init.constant_

補足としてnn.init.constant_で初期化する例と結果を載せておきます。なお、パラメータの形状(3,4)と初期値(100)は特に理由はないです。

# 適当に初期化(ここでは、とりあえず値がランダムに初期化されます)
sample_weight = nn.Parameter(torch.Tensor(3,4))
sample_weight

# 結果
# Parameter containing:
# tensor([[6.1473e+03, 1.6437e-42, 0.0000e+00, 0.0000e+00],
#         [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
#         [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00]], requires_grad=True)
# nn.init.constant_を適用
# 全て100で初期化
nn.init.constant_(sample_weight, 100)

# 結果
# Parameter containing:
# tensor([[100., 100., 100., 100.],
#         [100., 100., 100., 100.],
#         [100., 100., 100., 100.]], requires_grad=True)
0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?