0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

実装編 ― PyTorch で試すメトリック学習手法5選

Last updated at Posted at 2025-07-26

実装編 ― PyTorchで試す代表的計量学習手法5選

第1回「基礎編」で理論を解説した後は,実際のコードを動かして結果を確かめてみましょう.本記事では簡単なコード解説と例の提示,実装結果の紹介を行います.


1. リポジトリ構成 & セットアップ

以下のコマンドでGithubから,リポジトリをクローンする.

git clone https://github.com/SyunkiTakase/Metric_Learning_Method.git
cd Metric_Learning_Method

リポジトリの構成は以下のようになっている.

Metric_Learning_Method/
├── Config/
│   ├── SiameseNetwork.py
│   ├── TripletLoss.py
│   ├── ArcFace.py
│   ├── CosFace.py
│   └── SphereFace.py
├── base_model/
│   └── Xception.py
├── fig/                 ← 図解
├── method_config.py     ← 手法ごとの組み立てファクトリ
├── metric_model.py      ← Encoder+Head 動的読み込み
├── metric_loss.py       ← Contrastive/Triplet/ArcFace/CosFace/SphereFace
├── trainer.py           ← 学習・検証ループ + t-SNE フック
├── train_metric.py      ← エントリポイント
├── utils.py             ← load_config / save_model / save_log / save_featspace
└── README.md

2. コアモジュール解説

2.1 metric_model.py

バックボーン(ResNet/VGG/Xception)を動的にロードして,特徴量(embeddings) → 出力(logits)を返すモジュール.

class MetricModel(nn.Module):
    """
    メトリック学習モデルの定義

    Parameters
    ----------
    method : str
        使用するメトリック学習手法の名前
    arch : str
        使用するアーキテクチャの名前
    num_dim : int
        特徴量の次元数
    num_classes : int
        分類クラス数
    Returns
    -------
    None
    """
    def __init__(self, method='SiameseNetwork', arch='ResNet18', num_dim=512, num_classes=10):
        super(MetricModel, self).__init__()
        self.method = method
        self.arch = arch
        self.num_dim = num_dim
        self.num_classes = num_classes
        
        self.selected_arch(arch=self.arch)
        last_layers = ['fc', 'classifier', 'heads'] # 最終層の名前リスト

        for layer in last_layers: # 最終層の名前を順に確認
            if hasattr(self.encoder, layer):
                # 最終層がSequentialの場合
                if isinstance(getattr(self.encoder, layer), nn.Sequential):
                    self.out_enc_dim = getattr(self.encoder, layer)[0].in_features
                else:
                    self.out_enc_dim = getattr(self.encoder, layer).in_features
                break

        for layer in last_layers: # 最終層の名前を順に確認
            if hasattr(self.encoder, layer):
                # 最終層がSequentialの場合
                if isinstance(getattr(self.encoder, layer), nn.Sequential):
                    setattr(self.encoder, layer, nn.Sequential(nn.Identity()))
                else:
                    setattr(self.encoder, layer, nn.Identity())
                break

        if self.method == 'ArcFace' or self.method == 'CosFace' or self.method == 'SphereFace': # ArcFace/CosFace/SphereFaceを使用する場合
            self.classifier = nn.Linear(self.out_enc_dim, self.num_dim) # 特徴量の次元をnum_dimに変更
        else: # その他の手法を使用する場合
            self.classifier = nn.Linear(self.out_enc_dim, self.num_classes) # 分類クラス数に変更

    def forward(self, x):
        """
        順伝播の定義

        Parameters
        ----------
        x : torch.Tensor
            入力画像のテンソル
        Returns
        -------
        feat : torch.Tensor
            特徴量
        y : torch.Tensor
            分類結果
        """
        feat = self.encoder(x)
        y = self.classifier(feat)

        return feat, y
    
    def selected_arch(self, arch):
        """
        使用するアーキテクチャを選択する

        Parameters
        ----------
        arch : str
            使用するアーキテクチャの名前
        Returns
        -------
        None
        """
        arch_map = {
            'ResNet18': resnet18,
            'ResNet34': resnet34,
            'ResNet50': resnet50,
            'ResNet152': resnet152,
            'VGG11': vgg11,
            'VGG13': vgg13,
            'VGG16': vgg16,
            'VGG19': vgg19,
        }
        
        if arch in arch_map:
            self.encoder = arch_map[arch](weights='IMAGENET1K_V1')

        else:
            module_name = f'base_model.{arch}'
            try:
                module = importlib.import_module(module_name)
            except ImportError as e:
                raise ValueError(f'Unsupported architecture: {arch} (module {module_name} not found)') from e

            try:
                model_cls = getattr(module, arch)
            except AttributeError as e:
                raise ValueError(f'Module {module_name} does not define class {arch}') from e

            self.encoder = model_cls()

2.2 metric_loss.py

Contrastive, Triplet の距離系と,ArcFace/CosFace/SphereFace のマージン付き分類ヘッドを実装.

# Contrastive Loss
class ContrastiveLoss(nn.Module):
    """
    Contrastive Lossの定義

    Parameters
    ----------
    margin : float
        Contrastive Lossのマージン
    Returns
    -------
    loss : torch.Tensor
        Contrastive Lossの値
    """
    def __init__(self, margin=1.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, embeddings, labels):
        """
        Contrastive Lossの計算
        
        Parameters
        ----------
        embeddings : torch.Tensor
            入力特徴量の埋め込みベクトル
        labels : torch.Tensor
            各埋め込みに対応するラベル
        Returns
        -------
        loss : torch.Tensor
            Contrastive Lossの値
        """
        pairwise_dist = torch.cdist(embeddings, embeddings, p=2) # 特徴量間のペアワイズ距離の計算

        positive_mask = (labels.unsqueeze(1) == labels.unsqueeze(0)).float() # 同じラベルのマスク(ポジティブ)
        negative_mask = (labels.unsqueeze(1) != labels.unsqueeze(0)).float() # 異なるラベルのマスク(ネガティブ)

        positive_dist = pairwise_dist * positive_mask
        negative_dist = pairwise_dist * negative_mask

        # 各ペアの数をカウント
        num_pos_pairs = positive_mask.sum().item()
        num_neg_pairs = negative_mask.sum().item()

        if num_pos_pairs == 0 or num_neg_pairs == 0:
            return torch.tensor(0.0)

        # Contrastive Loss
        loss = (
            (positive_dist.pow(2) / 2).mean() +  
            (F.relu(self.margin - negative_dist + 1e-9).pow(2) / 2).mean() 
        )

        return loss

# Triplet Loss
class TripletLoss(nn.Module):
    """
    Triplet Lossの定義

    Parameters
    ----------
    margin : float
        Triplet Lossのマージン
    hard_triplets : bool
        Hard Tripletを使用するかどうか
    Returns
    -------
    loss : torch.Tensor
        Triplet Lossの値
    """
    def __init__(self, margin=1.0, hard_triplets=False):
        super(TripletLoss, self).__init__()
        self.margin = margin
        self.use_hard_triplets = hard_triplets

    def forward(self, embeddings, labels):
        """
        Triplet Lossの計算
        
        Parameters
        ----------
        embeddings : torch.Tensor
            入力特徴量の埋め込みベクトル
        labels : torch.Tensor
            各埋め込みに対応するラベル
        Returns
        -------
        loss : torch.Tensor
            Triplet Lossの値
        """
        pairwise_dist = torch.cdist(embeddings, embeddings, p=2) # 特徴量間のペアワイズ距離の計算

        positive_mask = (labels.unsqueeze(1) == labels.unsqueeze(0)).float() # ポジティブペアのマスク
        negative_mask = (labels.unsqueeze(1) != labels.unsqueeze(0)).float() # ネガティブペアのマスク
        # print('Pos Mask:', positive_mask)
        # print('Neg Mask:', negative_mask)

        if self.use_hard_triplets: # Hard PositiveおよびHard Negativeを選択
            positive_dist = pairwise_dist * positive_mask # ポジティブペアの距離
            positive_dist = positive_dist + (1 - positive_mask) * -1e6
            hardest_positive_dist, _ = positive_dist.max(dim=1) # 各アンカーに対する最も遠いポジティブ

            negative_dist = pairwise_dist + (1 - negative_mask) * 1e6 # 無効なネガティブに大きな値を設定
            hardest_negative_dist, _ = negative_dist.min(dim=1) # 各アンカーに対する最も近いネガティブ

            # Triplet Loss
            loss = torch.relu(hardest_positive_dist - hardest_negative_dist + self.margin)
            return loss.mean()

        else: # 全ペアを考慮
            positive_dist = pairwise_dist * positive_mask # ポジティブペアの距離
            # print('Pos Dist:', positive_dist)
            negative_dist = pairwise_dist * negative_mask # ネガティブペアの距離
            # print('Neg Dist:', negative_dist)

            # Triplet Loss
            triplet_loss = positive_dist.unsqueeze(2) - negative_dist.unsqueeze(1) + self.margin 
            triplet_loss = torch.relu(triplet_loss) # マージンに基づくReLU適用

            valid_triplets = positive_mask.unsqueeze(2) * negative_mask.unsqueeze(1) # 有効なTripletのマスク
            triplet_loss = triplet_loss * valid_triplets
            num_valid_triplets = valid_triplets.sum() + 1e-16 # 有効ペア数(ゼロ除算を防ぐ)
            loss = triplet_loss.sum() / num_valid_triplets
            return loss
        
# ArcFace
class ArcFaceHead(nn.Module):
    """
    ArcFace Lossの定義

    Parameters
    ----------
    in_features : int
        入力特徴量の次元数
    out_features : int
        出力クラス数
    s : float
        ArcFace Lossのスケーリング値
    m : float
        ArcFace Lossのマージン
    easy_margin : bool
        Easy Marginを使用するかどうか
    Returns
    -------
    loss : torch.Tensor
        ArcFace Lossの値
    """
    def __init__(self, in_features, out_features, s=30.0, m=0.50, easy_margin=False):
        super(ArcFaceHead, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        self.weight = Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

        self.easy_margin = easy_margin
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, input, label):
        """
        ArcFace Lossの計算

        Parameters
        ----------
        input : torch.Tensor
            入力特徴量の埋め込みベクトル
        label : torch.Tensor
            各埋め込みに対応するラベル
        Returns
        -------
        output : torch.Tensor
            ArcFace Lossの値
        """
        cosine = F.linear(F.normalize(input), F.normalize(self.weight)) # 入力特徴量と重みの内積計算
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1)) # 正規化されたコサイン値からサイン値を計算
        phi = cosine * self.cos_m - sine * self.sin_m # ArcFaceの角度変換
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine) # Easy Marginの場合,コサイン値が0より大きい場合はそのまま使用
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm) # 通常のマージンの場合,閾値以下のコサイン値に対してマージンを適用

        one_hot = torch.zeros(cosine.size(), device='cuda') 
        one_hot.scatter_(1, label.view(-1, 1).long(), 1) # one-hotエンコーディングの作成

        output = (one_hot * phi) + ((1.0 - one_hot) * cosine) # ラベルに基づいて出力を調整
        output *= self.s # スケーリング
        # print(output)

        return output

# CosFace
class CosFaceHead(nn.Module):
    """
    CosFace Lossの定義

    Parameters
    ----------
    in_features : int
        入力特徴量の次元数
    out_features : int
        出力クラス数
    s : float
        CosFace Lossのスケーリング値
    m : float
        CosFace Lossのマージン
    Returns
    -------
    loss : torch.Tensor
        CosFace Lossの値
    """
    def __init__(self, in_features, out_features, s=30.0, m=0.40):
        super(CosFaceHead, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        self.weight = Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, input, label):
        """
        CosFace Lossの計算

        Parameters
        ----------
        input : torch.Tensor
            入力特徴量の埋め込みベクトル
        label : torch.Tensor
            各埋め込みに対応するラベル
        Returns
        -------
        output : torch.Tensor
            CosFace Lossの値
        """
        cosine = F.linear(F.normalize(input), F.normalize(self.weight)) # 入力特徴量と重みの内積計算
        phi = cosine - self.m # CosFaceのマージンを適用

        one_hot = torch.zeros(cosine.size(), device='cuda')
        one_hot.scatter_(1, label.view(-1, 1).long(), 1) # one-hotエンコーディングの作成

        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)  # ラベルに基づいて出力を調整
        output *= self.s # スケーリング
        # print(output)

        return output

    def __repr__(self):
        return self.__class__.__name__ + '(' \
               + 'in_features=' + str(self.in_features) \
               + ', out_features=' + str(self.out_features) \
               + ', s=' + str(self.s) \
               + ', m=' + str(self.m) + ')'

# SphereFace
class SphereFaceHead(nn.Module):
    """
    SphereFace Lossの定義

    Parameters
    ----------
    in_features : int
        入力特徴量の次元数
    out_features : int
        出力クラス数
    m : float
        SphereFace Lossのマージン
    Returns
    -------
    loss : torch.Tensor
        SphereFace Lossの値
    """
    def __init__(self, in_features, out_features, m=4):
        super(SphereFaceHead, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.m = m
        self.base = 1000.0
        self.gamma = 0.12
        self.power = 1
        self.LambdaMin = 5.0
        self.iter = 0
        self.weight = Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform(self.weight)

        # duplication formula
        self.mlambda = [
            lambda x: x ** 0,
            lambda x: x ** 1,
            lambda x: 2 * x ** 2 - 1,
            lambda x: 4 * x ** 3 - 3 * x,
            lambda x: 8 * x ** 4 - 8 * x ** 2 + 1,
            lambda x: 16 * x ** 5 - 20 * x ** 3 + 5 * x
        ]

    def forward(self, input, label):
        """
        SphereFace Lossの計算

        Parameters
        ----------
        input : torch.Tensor
            入力特徴量の埋め込みベクトル
        label : torch.Tensor
            各埋め込みに対応するラベル
        Returns
        -------
        output : torch.Tensor
            SphereFace Lossの値
        """
        self.iter += 1 
        self.lamb = max(self.LambdaMin, self.base * (1 + self.gamma * self.iter) ** (-1 * self.power)) # スケーリング係数の計算

        cos_theta = F.linear(F.normalize(input), F.normalize(self.weight)) # 入力特徴量と重みの内積計算
        cos_theta = cos_theta.clamp(-1, 1) # コサイン値を[-1, 1]に制限
        cos_m_theta = self.mlambda[self.m](cos_theta) # マージンを適用
        theta = cos_theta.data.acos() # コサイン値から角度を計算
        k = (self.m * theta / 3.14159265).floor() # マージンに基づく係数を計算
        phi_theta = ((-1.0) ** k) * cos_m_theta - 2 * k # マージンを適用したコサイン値の計算
        NormOfFeature = torch.norm(input, 2, 1) # 入力特徴量のノルムを計算

        one_hot = torch.zeros(cos_theta.size()) 
        one_hot = one_hot.cuda() if cos_theta.is_cuda else one_hot # one-hotエンコーディングの作成
        one_hot.scatter_(1, label.view(-1, 1), 1) # one-hotエンコーディングをラベルに基づいて作成

        output = (one_hot * (phi_theta - cos_theta) / (1 + self.lamb)) + cos_theta # ラベルに基づいて出力を調整
        output *= NormOfFeature.view(-1, 1) # 入力特徴量のノルムでスケーリング

        return output

    def __repr__(self):
        return self.__class__.__name__ + '(' \
               + 'in_features=' + str(self.in_features) \
               + ', out_features=' + str(self.out_features) \
               + ', m=' + str(self.m) + ')'

2.3 method_config.py

各手法ごとにモデル・損失関数・Optimizer・学習/検証関数を組み立てるファクトリ.
ここでは,Siamese Networkのみ記載していますが,他手法も同様

def build_metric_method(method, arch, lr, weight_decay, optim_momentum, num_classes , num_dim, margin, scale, hard_triplets, easy_margin, device):
    """
    Metric Learning Methodを構築する関数

    Parameters
    ----------
    method : str
        使用するメトリック学習手法の名前
    arch : str
        使用するアーキテクチャの名前
    lr : float
        学習率
    weight_decay : float
        重み減衰
    optim_momentum : float
        オプティマイザのモメンタム
    num_classes : int
        クラス数
    num_dim : int
        特徴量の次元数
    margin : float
        マージンの値
    scale : float
        スケーリングの値
    hard_triplets : bool
        ハードトリプレットを使用するかどうか
    easy_margin : bool
        イージーマージンを使用するかどうか
    device : torch.device
        使用するデバイス(CPUまたはGPU)
    Returns
    -------
    model : MetricModel
        構築されたメトリック学習モデル
    metric_loss_func : nn.Module
        メトリック学習の損失関数
    optimizer : torch.optim.Optimizer
        モデルのオプティマイザ 
    train_fn : function
        学習用の関数
    validation_fn : function
        検証用の関数
    """
    if method == 'SiameseNetwork': # Siamese Networkを使用する場合
        model = MetricModel(method=method, arch=arch, num_classes=num_classes).to(device) # モデルの初期化
        metric_loss_func = ContrastiveLoss(margin=margin).to(device) # 損失関数の初期化
        optimizer = optim.Adam(model.parameters(), lr=lr) # オプティマイザの初期化

        return model, metric_loss_func, optimizer, train_metric, validation

2.4 trainer.py

学習ループと,検証時に t-SNE で特徴空間を可視化する save_map フックを実装.

  1. バッチごとに特徴量とラベルを収集
  2. sum loss & 正解数カウント
  3. 最後に特徴量を concat → return features_np, labels_np
def validation(device, val_loader, model, ce_loss_func, metric_loss_func):
    """
    検証用の関数

    Parameters
    ----------
    device : torch.device
        使用するデバイス(CPUまたはGPU)
    val_loader : torch.utils.data.DataLoader
        検証データローダー
    model : nn.Module
        学習するモデル
    ce_loss_func : nn.Module
        クロスエントロピー損失関数
    metric_loss_func : nn.Module
        メトリック学習の損失関数
    Returns
    -------
    sum_loss : float
        検証データに対する損失の合計
    count : int
        正しく分類されたサンプルの数
    features_np : np.ndarray
        特徴量のNumPy配列
    labels_np : np.ndarray
        ラベルのNumPy配列
    """
    model.eval()
    
    sum_loss = 0.0
    count = 0
    
    features_list = []
    labels_list = []

    with torch.no_grad():
        for idx, (imgs, labels) in enumerate(tqdm(val_loader)):
            imgs = imgs.to(device, non_blocking=True).float() # 画像をデバイスに転送
            labels = labels.to(device, non_blocking=True).long() # ラベルをデバイスに転送
            
            with torch.autocast(device_type="cuda", dtype=torch.float16):
                features, logits = model(imgs) # モデルの順伝播
                loss = ce_loss_func(logits, labels) # クロスエントロピー損失の計算

                features_list.append(features.cpu().detach().numpy()) # 特徴量をリストに追加
                labels_list.append(labels.cpu().detach().numpy()) # ラベルをリストに追加

            sum_loss += loss.item()
            count += torch.sum(logits.argmax(dim=1) == labels).item()

    features_np = np.concatenate(features_list, 0) # 特徴量のNumPy配列を作成
    labels_np = np.concatenate(labels_list, 0) # ラベルのNumPy配列を作成

    return sum_loss, count, features_np, labels_np

2.5 train_metric.py

コマンドライン引数で設定ファイルを読み込み,ループ内で学習・検証,ログ/モデル/t-SNEを保存します.

python train_metric.py --config_path Config/TripletLoss.py

毎エポックごとにt-SNEを保存.

if vis_featspace: # 特徴空間の可視化を行う場合
    save_featspace_path = featspace_path + str(epoch + 1) + '.png' # 特徴空間の保存パス
    save_featspace(features_np, labels_np, class_names, save_featspace_path) # 特徴空間の保存

3. 設定ファイルと実行方法

3.1 設定ファイル例(TripletLoss)

config = {
    "epoch": 20,
    "batch_size": 128,
    "lr": 1e-3,
    "weight_decay": 1e-4,
    "momentum": 0.9,
    "img_size": 224,
    "dataset": "cifar10",
    "margin": 1.0,
    "scale": None,
    "method": "TripletLoss",
    "arch": "ResNet18",
    "_lambda": 1.0,
    "num_dim": None,
    "hard_triplets": None,
    "easy_margin": None,
    "vis_featspace": True
}

3.2 実行例と結果

python train_metric.py --config_path Config/TripletLoss.py
手法 最高精度
SiameseNetwork 0.9253
TripletLoss 0.9535
TripletLoss+Hard 0.9575
ArcFace 0.9167
CosFace 0.9205
SphereFace 0.9240

最も高い精度を記録した手法は:TripletLoss+Hard(0.9575) です.
ArcFace/CosFace/SphereFace はハイパラ(lr/scale/margin)のチューニングが鍵になります.

3.3. t-SNE による特徴空間の可視化

vis_featspace=True とした場合,10エポックごとに以下のような図を保存します.

Epoch 10
10.png

Epoch 20
20.png

クラス間の分離が明瞭に特徴量が意味的に集約されていることが視覚的に確認できる

4. おわりに

他データセットやバックボーンへの置き換えが容易です.
Warmup や学習率スケジューラなどを組み合わせると精度がさらに向上します.

GitHub リポジトリ:https://github.com/SyunkiTakase/Metric_Learning_Method

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?