11
16

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

mnistを距離学習(metric learning)してt-SNEで可視化してみる(pytorch)

Posted at

metric learningにおけるangular lossn-pair lossの特性の違いを理解したかったので、pytorchで実装し、mnistで分類してからt-SNEで可視化してみました。
実装はgithubにあります。
https://github.com/tomp11/metric_learning
また、N-Pair LossとAngular Lossについては
距離学習におけるN-Pair LossとAngular Lossの理解と実装(Pytorch)
でも詳しく説明しているのでそちらも参考にしてください。
#結果
最初に結果を比較してみます。

n_pair_loss n_pair_angular_loss
n_pair.jpg angular.jpg

n-pair lossは少し曖昧に分ける印象があります。3と8が重なってしまってる部分も見られます。しかし分類が難しいものでもそのままにせず、どこかのグループには所属しているようです。
angular lossはわりとはっきりとわけているように見えます。n-pair lossで混同しがちだった3と8も遠く離れています。ところが難しいデータは追いやっている傾向が高そうです。実際今回このデータにおいてはaccurancyはn-pair lossよりは劣りました。
しかし単純なモデルなので今回だけの結果だけでは良し悪しは決めれません。またt-SNEは高次元データを2次元に無理やり押し込んでいるので、参考程度にしてください。

#実装
通常のネットワークと違う特徴的なところを解説していきます。

###Dataset

Dataset.py
class N_Pair_ImageDataset(torch.utils.data.Dataset):
    def __init__(self, base_path, filenames_filename, n_pair_file_name, transform,
                 loader=default_image_loader):
        self.base_path = base_path
        self.filenamelist = []
        for line in open(filenames_filename):
            self.filenamelist.append(line.rstrip('\n'))
        paths = []
        for line in open(n_pair_file_name):
            paths.append(([i for i in line.split(",")[0].split()], [i for i in line.split(",")[1].split()])) # ([anchors],[positives])
        self.paths = paths
        self.transform = transform
        self.loader = loader
    def __getitem__(self, index):
        def path2img(path):
            img = self.loader(os.path.join(self.base_path,self.filenamelist[int(path)]))
            return img

        anchor_imgs = [self.transform(path2img(path)) for path in self.paths[index][0]]
        positives_imgs = [self.transform(path2img(path)) for path in self.paths[index][1]]
        anchor_imgs , positives_imgs = torch.stack(anchor_imgs), torch.stack(positives_imgs)
        return anchor_imgs, positives_imgs

2組を一度に入れるために
898 8564 4870 5366 6851 ,33 8777 5010 5553 6700
1544 3567 5438 4221 570 ,1455 3922 5225 4497 732
3483 9249 2897 7654 5723 ,3218 9619 2217 7307 5798
のような画像のパスに対応したインデックスが羅列されたテキストファイルをSampling.pyで作り出し、
一行づつ読み込んでいます。

###model

Models.py
class N_PAIR_net(nn.Module):
    def __init__(self, embeddingnet):
        super(N_PAIR_net, self).__init__()
        self.embeddingnet = embeddingnet

    def forward(self, anchors, positives):
        f = self.embeddingnet(anchors)
        f_p = self.embeddingnet(positives)
        return f, f_p

n-pairの距離学習では出力を2組出さなければいけないので2組の入力をおなじネットワークに入れます。

###loss
lossについては以下も参考にしてみてください。
距離学習におけるN-Pair LossとAngular Lossの理解と実装(Pytorch)

Loss.py
class Angular_mc_loss(nn.Module):
    def __init__(self, alpha=45, in_degree=True):
        super(Angular_mc_loss, self).__init__()
        if in_degree:
            alpha = np.deg2rad(alpha)
        self.sq_tan_alpha = np.tan(alpha) ** 2

    def forward(self, f, f_p, with_npair=True, lamb=2):
        n_pairs = len(f)
        term1 = 4 * self.sq_tan_alpha * torch.matmul(f + f_p, torch.transpose(f_p, 0, 1))
        term2 = 2 * (1 + self.sq_tan_alpha) * torch.sum(f * f_p, keepdim=True, dim=1)
        f_apn = term1 - term2
        mask = torch.ones_like(f_apn) - torch.eye(n_pairs).cuda()
        f_apn = f_apn * mask
        loss = torch.mean(torch.logsumexp(f_apn, dim=1))
        if with_npair:
            loss_npair = self.n_pair_mc_loss(f, f_p)
            loss = loss_npair + lamb*loss
        return loss

    @staticmethod
    def n_pair_mc_loss(f, f_p):
        n_pairs = len(f)
        term1 = torch.matmul(f, torch.transpose(f_p, 0, 1))
        term2 = torch.sum(f * f_p, keepdim=True, dim=1)
        f_apn = term1 - term2
        mask = torch.ones_like(f_apn) - torch.eye(n_pairs).cuda()
        f_apn = f_apn * mask
        return torch.mean(torch.logsumexp(f_apn, dim=1))

with_npairがtrueのときはn-pair lossと組み合わせたlossを使うことができます。(論文ではNL&ALと書かれているもの)

###t_SNE

def t_sne():

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root='./mnist_test', train=False, download=True, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])), batch_size=1000, shuffle=True, num_workers=4)
    test_set = datasets.MNIST(root='./mnist_test', train=False, download=True, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ]))

    test_batch = []
    target = []
    count = 0
    for i, j in test_set:
        test_batch.append(i)
        target.append(j)
        count+=1
        # print(i.size())
        if count==5000:
            break
    data = torch.stack(test_batch) # (5000,1,28,28)
    target = torch.Tensor(target) # (5000)

    net = Cnn_32()
    model = Test_net(net).to(device)
    model.load_state_dict(torch.load('./checkpoints/checkpoint.pth.tar'))
    latent_vecs = auto_encode(model, device, data)
    latent_vecs, target = latent_vecs.to("cpu"), target.to("cpu")
    latent_vecs, target = latent_vecs.numpy(), target.numpy()
    print(latent_vecs.shape, target.shape)
    latent_vecs_reduced = TSNE(n_components=2, random_state=0).fit_transform(latent_vecs)

    plt.scatter(latent_vecs_reduced[:, 0], latent_vecs_reduced[:, 1],
                c=target, cmap='jet')
    plt.colorbar()
    plt.show()

scikit-learnに便利なt-sneがあるのでそれを使っています。(速度は遅いようですが...)
難しいことはよくわかりませんがとりあえずTSNE関数に特徴ベクトルをぶち込みます。

#参考
https://copypaste-ds.hatenablog.com/entry/2019/03/01/164155
http://inaz2.hatenablog.com/entry/2017/01/24/211331

11
16
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
11
16

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?