実装編 ― PyTorchで試す代表的計量学習手法5選
第1回「基礎編」で理論を解説した後は,実際のコードを動かして結果を確かめてみましょう.本記事では簡単なコード解説と例の提示,実装結果の紹介を行います.
1. リポジトリ構成 & セットアップ
以下のコマンドでGithubから,リポジトリをクローンする.
git clone https://github.com/SyunkiTakase/Metric_Learning_Method.git
cd Metric_Learning_Method
リポジトリの構成は以下のようになっている.
Metric_Learning_Method/
├── Config/
│ ├── SiameseNetwork.py
│ ├── TripletLoss.py
│ ├── ArcFace.py
│ ├── CosFace.py
│ └── SphereFace.py
├── base_model/
│ └── Xception.py
├── fig/ ← 図解
├── method_config.py ← 手法ごとの組み立てファクトリ
├── metric_model.py ← Encoder+Head 動的読み込み
├── metric_loss.py ← Contrastive/Triplet/ArcFace/CosFace/SphereFace
├── trainer.py ← 学習・検証ループ + t-SNE フック
├── train_metric.py ← エントリポイント
├── utils.py ← load_config / save_model / save_log / save_featspace
└── README.md
2. コアモジュール解説
2.1 metric_model.py
バックボーン(ResNet/VGG/Xception)を動的にロードして,特徴量(embeddings) → 出力(logits)を返すモジュール.
class MetricModel(nn.Module):
"""
メトリック学習モデルの定義
Parameters
----------
method : str
使用するメトリック学習手法の名前
arch : str
使用するアーキテクチャの名前
num_dim : int
特徴量の次元数
num_classes : int
分類クラス数
Returns
-------
None
"""
def __init__(self, method='SiameseNetwork', arch='ResNet18', num_dim=512, num_classes=10):
super(MetricModel, self).__init__()
self.method = method
self.arch = arch
self.num_dim = num_dim
self.num_classes = num_classes
self.selected_arch(arch=self.arch)
last_layers = ['fc', 'classifier', 'heads'] # 最終層の名前リスト
for layer in last_layers: # 最終層の名前を順に確認
if hasattr(self.encoder, layer):
# 最終層がSequentialの場合
if isinstance(getattr(self.encoder, layer), nn.Sequential):
self.out_enc_dim = getattr(self.encoder, layer)[0].in_features
else:
self.out_enc_dim = getattr(self.encoder, layer).in_features
break
for layer in last_layers: # 最終層の名前を順に確認
if hasattr(self.encoder, layer):
# 最終層がSequentialの場合
if isinstance(getattr(self.encoder, layer), nn.Sequential):
setattr(self.encoder, layer, nn.Sequential(nn.Identity()))
else:
setattr(self.encoder, layer, nn.Identity())
break
if self.method == 'ArcFace' or self.method == 'CosFace' or self.method == 'SphereFace': # ArcFace/CosFace/SphereFaceを使用する場合
self.classifier = nn.Linear(self.out_enc_dim, self.num_dim) # 特徴量の次元をnum_dimに変更
else: # その他の手法を使用する場合
self.classifier = nn.Linear(self.out_enc_dim, self.num_classes) # 分類クラス数に変更
def forward(self, x):
"""
順伝播の定義
Parameters
----------
x : torch.Tensor
入力画像のテンソル
Returns
-------
feat : torch.Tensor
特徴量
y : torch.Tensor
分類結果
"""
feat = self.encoder(x)
y = self.classifier(feat)
return feat, y
def selected_arch(self, arch):
"""
使用するアーキテクチャを選択する
Parameters
----------
arch : str
使用するアーキテクチャの名前
Returns
-------
None
"""
arch_map = {
'ResNet18': resnet18,
'ResNet34': resnet34,
'ResNet50': resnet50,
'ResNet152': resnet152,
'VGG11': vgg11,
'VGG13': vgg13,
'VGG16': vgg16,
'VGG19': vgg19,
}
if arch in arch_map:
self.encoder = arch_map[arch](weights='IMAGENET1K_V1')
else:
module_name = f'base_model.{arch}'
try:
module = importlib.import_module(module_name)
except ImportError as e:
raise ValueError(f'Unsupported architecture: {arch} (module {module_name} not found)') from e
try:
model_cls = getattr(module, arch)
except AttributeError as e:
raise ValueError(f'Module {module_name} does not define class {arch}') from e
self.encoder = model_cls()
2.2 metric_loss.py
Contrastive, Triplet の距離系と,ArcFace/CosFace/SphereFace のマージン付き分類ヘッドを実装.
# Contrastive Loss
class ContrastiveLoss(nn.Module):
"""
Contrastive Lossの定義
Parameters
----------
margin : float
Contrastive Lossのマージン
Returns
-------
loss : torch.Tensor
Contrastive Lossの値
"""
def __init__(self, margin=1.0):
super(ContrastiveLoss, self).__init__()
self.margin = margin
def forward(self, embeddings, labels):
"""
Contrastive Lossの計算
Parameters
----------
embeddings : torch.Tensor
入力特徴量の埋め込みベクトル
labels : torch.Tensor
各埋め込みに対応するラベル
Returns
-------
loss : torch.Tensor
Contrastive Lossの値
"""
pairwise_dist = torch.cdist(embeddings, embeddings, p=2) # 特徴量間のペアワイズ距離の計算
positive_mask = (labels.unsqueeze(1) == labels.unsqueeze(0)).float() # 同じラベルのマスク(ポジティブ)
negative_mask = (labels.unsqueeze(1) != labels.unsqueeze(0)).float() # 異なるラベルのマスク(ネガティブ)
positive_dist = pairwise_dist * positive_mask
negative_dist = pairwise_dist * negative_mask
# 各ペアの数をカウント
num_pos_pairs = positive_mask.sum().item()
num_neg_pairs = negative_mask.sum().item()
if num_pos_pairs == 0 or num_neg_pairs == 0:
return torch.tensor(0.0)
# Contrastive Loss
loss = (
(positive_dist.pow(2) / 2).mean() +
(F.relu(self.margin - negative_dist + 1e-9).pow(2) / 2).mean()
)
return loss
# Triplet Loss
class TripletLoss(nn.Module):
"""
Triplet Lossの定義
Parameters
----------
margin : float
Triplet Lossのマージン
hard_triplets : bool
Hard Tripletを使用するかどうか
Returns
-------
loss : torch.Tensor
Triplet Lossの値
"""
def __init__(self, margin=1.0, hard_triplets=False):
super(TripletLoss, self).__init__()
self.margin = margin
self.use_hard_triplets = hard_triplets
def forward(self, embeddings, labels):
"""
Triplet Lossの計算
Parameters
----------
embeddings : torch.Tensor
入力特徴量の埋め込みベクトル
labels : torch.Tensor
各埋め込みに対応するラベル
Returns
-------
loss : torch.Tensor
Triplet Lossの値
"""
pairwise_dist = torch.cdist(embeddings, embeddings, p=2) # 特徴量間のペアワイズ距離の計算
positive_mask = (labels.unsqueeze(1) == labels.unsqueeze(0)).float() # ポジティブペアのマスク
negative_mask = (labels.unsqueeze(1) != labels.unsqueeze(0)).float() # ネガティブペアのマスク
# print('Pos Mask:', positive_mask)
# print('Neg Mask:', negative_mask)
if self.use_hard_triplets: # Hard PositiveおよびHard Negativeを選択
positive_dist = pairwise_dist * positive_mask # ポジティブペアの距離
positive_dist = positive_dist + (1 - positive_mask) * -1e6
hardest_positive_dist, _ = positive_dist.max(dim=1) # 各アンカーに対する最も遠いポジティブ
negative_dist = pairwise_dist + (1 - negative_mask) * 1e6 # 無効なネガティブに大きな値を設定
hardest_negative_dist, _ = negative_dist.min(dim=1) # 各アンカーに対する最も近いネガティブ
# Triplet Loss
loss = torch.relu(hardest_positive_dist - hardest_negative_dist + self.margin)
return loss.mean()
else: # 全ペアを考慮
positive_dist = pairwise_dist * positive_mask # ポジティブペアの距離
# print('Pos Dist:', positive_dist)
negative_dist = pairwise_dist * negative_mask # ネガティブペアの距離
# print('Neg Dist:', negative_dist)
# Triplet Loss
triplet_loss = positive_dist.unsqueeze(2) - negative_dist.unsqueeze(1) + self.margin
triplet_loss = torch.relu(triplet_loss) # マージンに基づくReLU適用
valid_triplets = positive_mask.unsqueeze(2) * negative_mask.unsqueeze(1) # 有効なTripletのマスク
triplet_loss = triplet_loss * valid_triplets
num_valid_triplets = valid_triplets.sum() + 1e-16 # 有効ペア数(ゼロ除算を防ぐ)
loss = triplet_loss.sum() / num_valid_triplets
return loss
# ArcFace
class ArcFaceHead(nn.Module):
"""
ArcFace Lossの定義
Parameters
----------
in_features : int
入力特徴量の次元数
out_features : int
出力クラス数
s : float
ArcFace Lossのスケーリング値
m : float
ArcFace Lossのマージン
easy_margin : bool
Easy Marginを使用するかどうか
Returns
-------
loss : torch.Tensor
ArcFace Lossの値
"""
def __init__(self, in_features, out_features, s=30.0, m=0.50, easy_margin=False):
super(ArcFaceHead, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.s = s
self.m = m
self.weight = Parameter(torch.FloatTensor(out_features, in_features))
nn.init.xavier_uniform_(self.weight)
self.easy_margin = easy_margin
self.cos_m = math.cos(m)
self.sin_m = math.sin(m)
self.th = math.cos(math.pi - m)
self.mm = math.sin(math.pi - m) * m
def forward(self, input, label):
"""
ArcFace Lossの計算
Parameters
----------
input : torch.Tensor
入力特徴量の埋め込みベクトル
label : torch.Tensor
各埋め込みに対応するラベル
Returns
-------
output : torch.Tensor
ArcFace Lossの値
"""
cosine = F.linear(F.normalize(input), F.normalize(self.weight)) # 入力特徴量と重みの内積計算
sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1)) # 正規化されたコサイン値からサイン値を計算
phi = cosine * self.cos_m - sine * self.sin_m # ArcFaceの角度変換
if self.easy_margin:
phi = torch.where(cosine > 0, phi, cosine) # Easy Marginの場合,コサイン値が0より大きい場合はそのまま使用
else:
phi = torch.where(cosine > self.th, phi, cosine - self.mm) # 通常のマージンの場合,閾値以下のコサイン値に対してマージンを適用
one_hot = torch.zeros(cosine.size(), device='cuda')
one_hot.scatter_(1, label.view(-1, 1).long(), 1) # one-hotエンコーディングの作成
output = (one_hot * phi) + ((1.0 - one_hot) * cosine) # ラベルに基づいて出力を調整
output *= self.s # スケーリング
# print(output)
return output
# CosFace
class CosFaceHead(nn.Module):
"""
CosFace Lossの定義
Parameters
----------
in_features : int
入力特徴量の次元数
out_features : int
出力クラス数
s : float
CosFace Lossのスケーリング値
m : float
CosFace Lossのマージン
Returns
-------
loss : torch.Tensor
CosFace Lossの値
"""
def __init__(self, in_features, out_features, s=30.0, m=0.40):
super(CosFaceHead, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.s = s
self.m = m
self.weight = Parameter(torch.FloatTensor(out_features, in_features))
nn.init.xavier_uniform_(self.weight)
def forward(self, input, label):
"""
CosFace Lossの計算
Parameters
----------
input : torch.Tensor
入力特徴量の埋め込みベクトル
label : torch.Tensor
各埋め込みに対応するラベル
Returns
-------
output : torch.Tensor
CosFace Lossの値
"""
cosine = F.linear(F.normalize(input), F.normalize(self.weight)) # 入力特徴量と重みの内積計算
phi = cosine - self.m # CosFaceのマージンを適用
one_hot = torch.zeros(cosine.size(), device='cuda')
one_hot.scatter_(1, label.view(-1, 1).long(), 1) # one-hotエンコーディングの作成
output = (one_hot * phi) + ((1.0 - one_hot) * cosine) # ラベルに基づいて出力を調整
output *= self.s # スケーリング
# print(output)
return output
def __repr__(self):
return self.__class__.__name__ + '(' \
+ 'in_features=' + str(self.in_features) \
+ ', out_features=' + str(self.out_features) \
+ ', s=' + str(self.s) \
+ ', m=' + str(self.m) + ')'
# SphereFace
class SphereFaceHead(nn.Module):
"""
SphereFace Lossの定義
Parameters
----------
in_features : int
入力特徴量の次元数
out_features : int
出力クラス数
m : float
SphereFace Lossのマージン
Returns
-------
loss : torch.Tensor
SphereFace Lossの値
"""
def __init__(self, in_features, out_features, m=4):
super(SphereFaceHead, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.m = m
self.base = 1000.0
self.gamma = 0.12
self.power = 1
self.LambdaMin = 5.0
self.iter = 0
self.weight = Parameter(torch.FloatTensor(out_features, in_features))
nn.init.xavier_uniform(self.weight)
# duplication formula
self.mlambda = [
lambda x: x ** 0,
lambda x: x ** 1,
lambda x: 2 * x ** 2 - 1,
lambda x: 4 * x ** 3 - 3 * x,
lambda x: 8 * x ** 4 - 8 * x ** 2 + 1,
lambda x: 16 * x ** 5 - 20 * x ** 3 + 5 * x
]
def forward(self, input, label):
"""
SphereFace Lossの計算
Parameters
----------
input : torch.Tensor
入力特徴量の埋め込みベクトル
label : torch.Tensor
各埋め込みに対応するラベル
Returns
-------
output : torch.Tensor
SphereFace Lossの値
"""
self.iter += 1
self.lamb = max(self.LambdaMin, self.base * (1 + self.gamma * self.iter) ** (-1 * self.power)) # スケーリング係数の計算
cos_theta = F.linear(F.normalize(input), F.normalize(self.weight)) # 入力特徴量と重みの内積計算
cos_theta = cos_theta.clamp(-1, 1) # コサイン値を[-1, 1]に制限
cos_m_theta = self.mlambda[self.m](cos_theta) # マージンを適用
theta = cos_theta.data.acos() # コサイン値から角度を計算
k = (self.m * theta / 3.14159265).floor() # マージンに基づく係数を計算
phi_theta = ((-1.0) ** k) * cos_m_theta - 2 * k # マージンを適用したコサイン値の計算
NormOfFeature = torch.norm(input, 2, 1) # 入力特徴量のノルムを計算
one_hot = torch.zeros(cos_theta.size())
one_hot = one_hot.cuda() if cos_theta.is_cuda else one_hot # one-hotエンコーディングの作成
one_hot.scatter_(1, label.view(-1, 1), 1) # one-hotエンコーディングをラベルに基づいて作成
output = (one_hot * (phi_theta - cos_theta) / (1 + self.lamb)) + cos_theta # ラベルに基づいて出力を調整
output *= NormOfFeature.view(-1, 1) # 入力特徴量のノルムでスケーリング
return output
def __repr__(self):
return self.__class__.__name__ + '(' \
+ 'in_features=' + str(self.in_features) \
+ ', out_features=' + str(self.out_features) \
+ ', m=' + str(self.m) + ')'
2.3 method_config.py
各手法ごとにモデル・損失関数・Optimizer・学習/検証関数を組み立てるファクトリ.
ここでは,Siamese Networkのみ記載していますが,他手法も同様
def build_metric_method(method, arch, lr, weight_decay, optim_momentum, num_classes , num_dim, margin, scale, hard_triplets, easy_margin, device):
"""
Metric Learning Methodを構築する関数
Parameters
----------
method : str
使用するメトリック学習手法の名前
arch : str
使用するアーキテクチャの名前
lr : float
学習率
weight_decay : float
重み減衰
optim_momentum : float
オプティマイザのモメンタム
num_classes : int
クラス数
num_dim : int
特徴量の次元数
margin : float
マージンの値
scale : float
スケーリングの値
hard_triplets : bool
ハードトリプレットを使用するかどうか
easy_margin : bool
イージーマージンを使用するかどうか
device : torch.device
使用するデバイス(CPUまたはGPU)
Returns
-------
model : MetricModel
構築されたメトリック学習モデル
metric_loss_func : nn.Module
メトリック学習の損失関数
optimizer : torch.optim.Optimizer
モデルのオプティマイザ
train_fn : function
学習用の関数
validation_fn : function
検証用の関数
"""
if method == 'SiameseNetwork': # Siamese Networkを使用する場合
model = MetricModel(method=method, arch=arch, num_classes=num_classes).to(device) # モデルの初期化
metric_loss_func = ContrastiveLoss(margin=margin).to(device) # 損失関数の初期化
optimizer = optim.Adam(model.parameters(), lr=lr) # オプティマイザの初期化
return model, metric_loss_func, optimizer, train_metric, validation
2.4 trainer.py
学習ループと,検証時に t-SNE で特徴空間を可視化する save_map フックを実装.
- バッチごとに特徴量とラベルを収集
- sum loss & 正解数カウント
- 最後に特徴量を concat → return features_np, labels_np
def validation(device, val_loader, model, ce_loss_func, metric_loss_func):
"""
検証用の関数
Parameters
----------
device : torch.device
使用するデバイス(CPUまたはGPU)
val_loader : torch.utils.data.DataLoader
検証データローダー
model : nn.Module
学習するモデル
ce_loss_func : nn.Module
クロスエントロピー損失関数
metric_loss_func : nn.Module
メトリック学習の損失関数
Returns
-------
sum_loss : float
検証データに対する損失の合計
count : int
正しく分類されたサンプルの数
features_np : np.ndarray
特徴量のNumPy配列
labels_np : np.ndarray
ラベルのNumPy配列
"""
model.eval()
sum_loss = 0.0
count = 0
features_list = []
labels_list = []
with torch.no_grad():
for idx, (imgs, labels) in enumerate(tqdm(val_loader)):
imgs = imgs.to(device, non_blocking=True).float() # 画像をデバイスに転送
labels = labels.to(device, non_blocking=True).long() # ラベルをデバイスに転送
with torch.autocast(device_type="cuda", dtype=torch.float16):
features, logits = model(imgs) # モデルの順伝播
loss = ce_loss_func(logits, labels) # クロスエントロピー損失の計算
features_list.append(features.cpu().detach().numpy()) # 特徴量をリストに追加
labels_list.append(labels.cpu().detach().numpy()) # ラベルをリストに追加
sum_loss += loss.item()
count += torch.sum(logits.argmax(dim=1) == labels).item()
features_np = np.concatenate(features_list, 0) # 特徴量のNumPy配列を作成
labels_np = np.concatenate(labels_list, 0) # ラベルのNumPy配列を作成
return sum_loss, count, features_np, labels_np
2.5 train_metric.py
コマンドライン引数で設定ファイルを読み込み,ループ内で学習・検証,ログ/モデル/t-SNEを保存します.
python train_metric.py --config_path Config/TripletLoss.py
毎エポックごとにt-SNEを保存.
if vis_featspace: # 特徴空間の可視化を行う場合
save_featspace_path = featspace_path + str(epoch + 1) + '.png' # 特徴空間の保存パス
save_featspace(features_np, labels_np, class_names, save_featspace_path) # 特徴空間の保存
3. 設定ファイルと実行方法
3.1 設定ファイル例(TripletLoss)
config = {
"epoch": 20,
"batch_size": 128,
"lr": 1e-3,
"weight_decay": 1e-4,
"momentum": 0.9,
"img_size": 224,
"dataset": "cifar10",
"margin": 1.0,
"scale": None,
"method": "TripletLoss",
"arch": "ResNet18",
"_lambda": 1.0,
"num_dim": None,
"hard_triplets": None,
"easy_margin": None,
"vis_featspace": True
}
3.2 実行例と結果
python train_metric.py --config_path Config/TripletLoss.py
手法 | 最高精度 |
---|---|
SiameseNetwork | 0.9253 |
TripletLoss | 0.9535 |
TripletLoss+Hard | 0.9575 |
ArcFace | 0.9167 |
CosFace | 0.9205 |
SphereFace | 0.9240 |
最も高い精度を記録した手法は:TripletLoss+Hard(0.9575) です.
ArcFace/CosFace/SphereFace はハイパラ(lr/scale/margin)のチューニングが鍵になります.
3.3. t-SNE による特徴空間の可視化
vis_featspace=True とした場合,10エポックごとに以下のような図を保存します.
クラス間の分離が明瞭に特徴量が意味的に集約されていることが視覚的に確認できる
4. おわりに
他データセットやバックボーンへの置き換えが容易です.
Warmup や学習率スケジューラなどを組み合わせると精度がさらに向上します.
GitHub リポジトリ:https://github.com/SyunkiTakase/Metric_Learning_Method