metric learningにおけるangular lossとn-pair lossの特性の違いを理解したかったので、pytorchで実装し、mnistで分類してからt-SNEで可視化してみました。
実装はgithubにあります。
https://github.com/tomp11/metric_learning
また、N-Pair LossとAngular Lossについては
距離学習におけるN-Pair LossとAngular Lossの理解と実装(Pytorch)
でも詳しく説明しているのでそちらも参考にしてください。
#結果
最初に結果を比較してみます。
n_pair_loss | n_pair_angular_loss |
---|---|
n-pair lossは少し曖昧に分ける印象があります。3と8が重なってしまってる部分も見られます。しかし分類が難しいものでもそのままにせず、どこかのグループには所属しているようです。
angular lossはわりとはっきりとわけているように見えます。n-pair lossで混同しがちだった3と8も遠く離れています。ところが難しいデータは追いやっている傾向が高そうです。実際今回このデータにおいてはaccurancyはn-pair lossよりは劣りました。
しかし単純なモデルなので今回だけの結果だけでは良し悪しは決めれません。またt-SNEは高次元データを2次元に無理やり押し込んでいるので、参考程度にしてください。
#実装
通常のネットワークと違う特徴的なところを解説していきます。
###Dataset
class N_Pair_ImageDataset(torch.utils.data.Dataset):
def __init__(self, base_path, filenames_filename, n_pair_file_name, transform,
loader=default_image_loader):
self.base_path = base_path
self.filenamelist = []
for line in open(filenames_filename):
self.filenamelist.append(line.rstrip('\n'))
paths = []
for line in open(n_pair_file_name):
paths.append(([i for i in line.split(",")[0].split()], [i for i in line.split(",")[1].split()])) # ([anchors],[positives])
self.paths = paths
self.transform = transform
self.loader = loader
def __getitem__(self, index):
def path2img(path):
img = self.loader(os.path.join(self.base_path,self.filenamelist[int(path)]))
return img
anchor_imgs = [self.transform(path2img(path)) for path in self.paths[index][0]]
positives_imgs = [self.transform(path2img(path)) for path in self.paths[index][1]]
anchor_imgs , positives_imgs = torch.stack(anchor_imgs), torch.stack(positives_imgs)
return anchor_imgs, positives_imgs
2組を一度に入れるために
898 8564 4870 5366 6851 ,33 8777 5010 5553 6700
1544 3567 5438 4221 570 ,1455 3922 5225 4497 732
3483 9249 2897 7654 5723 ,3218 9619 2217 7307 5798
のような画像のパスに対応したインデックスが羅列されたテキストファイルをSampling.py
で作り出し、
一行づつ読み込んでいます。
###model
class N_PAIR_net(nn.Module):
def __init__(self, embeddingnet):
super(N_PAIR_net, self).__init__()
self.embeddingnet = embeddingnet
def forward(self, anchors, positives):
f = self.embeddingnet(anchors)
f_p = self.embeddingnet(positives)
return f, f_p
n-pairの距離学習では出力を2組出さなければいけないので2組の入力をおなじネットワークに入れます。
###loss
lossについては以下も参考にしてみてください。
距離学習におけるN-Pair LossとAngular Lossの理解と実装(Pytorch)
class Angular_mc_loss(nn.Module):
def __init__(self, alpha=45, in_degree=True):
super(Angular_mc_loss, self).__init__()
if in_degree:
alpha = np.deg2rad(alpha)
self.sq_tan_alpha = np.tan(alpha) ** 2
def forward(self, f, f_p, with_npair=True, lamb=2):
n_pairs = len(f)
term1 = 4 * self.sq_tan_alpha * torch.matmul(f + f_p, torch.transpose(f_p, 0, 1))
term2 = 2 * (1 + self.sq_tan_alpha) * torch.sum(f * f_p, keepdim=True, dim=1)
f_apn = term1 - term2
mask = torch.ones_like(f_apn) - torch.eye(n_pairs).cuda()
f_apn = f_apn * mask
loss = torch.mean(torch.logsumexp(f_apn, dim=1))
if with_npair:
loss_npair = self.n_pair_mc_loss(f, f_p)
loss = loss_npair + lamb*loss
return loss
@staticmethod
def n_pair_mc_loss(f, f_p):
n_pairs = len(f)
term1 = torch.matmul(f, torch.transpose(f_p, 0, 1))
term2 = torch.sum(f * f_p, keepdim=True, dim=1)
f_apn = term1 - term2
mask = torch.ones_like(f_apn) - torch.eye(n_pairs).cuda()
f_apn = f_apn * mask
return torch.mean(torch.logsumexp(f_apn, dim=1))
with_npair
がtrueのときはn-pair lossと組み合わせたlossを使うことができます。(論文ではNL&ALと書かれているもの)
###t_SNE
def t_sne():
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
test_loader = torch.utils.data.DataLoader(
datasets.MNIST(root='./mnist_test', train=False, download=True, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])), batch_size=1000, shuffle=True, num_workers=4)
test_set = datasets.MNIST(root='./mnist_test', train=False, download=True, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
]))
test_batch = []
target = []
count = 0
for i, j in test_set:
test_batch.append(i)
target.append(j)
count+=1
# print(i.size())
if count==5000:
break
data = torch.stack(test_batch) # (5000,1,28,28)
target = torch.Tensor(target) # (5000)
net = Cnn_32()
model = Test_net(net).to(device)
model.load_state_dict(torch.load('./checkpoints/checkpoint.pth.tar'))
latent_vecs = auto_encode(model, device, data)
latent_vecs, target = latent_vecs.to("cpu"), target.to("cpu")
latent_vecs, target = latent_vecs.numpy(), target.numpy()
print(latent_vecs.shape, target.shape)
latent_vecs_reduced = TSNE(n_components=2, random_state=0).fit_transform(latent_vecs)
plt.scatter(latent_vecs_reduced[:, 0], latent_vecs_reduced[:, 1],
c=target, cmap='jet')
plt.colorbar()
plt.show()
scikit-learnに便利なt-sneがあるのでそれを使っています。(速度は遅いようですが...)
難しいことはよくわかりませんがとりあえずTSNE関数に特徴ベクトルをぶち込みます。
#参考
https://copypaste-ds.hatenablog.com/entry/2019/03/01/164155
http://inaz2.hatenablog.com/entry/2017/01/24/211331