LoginSignup
87
70

More than 3 years have passed since last update.

距離を近づけろ!Pytorch Metric Learningで始めるDeep Metric Learning

Last updated at Posted at 2020-12-11

NTTドコモ R&D 控え室のアドベントカレンダー12日目の記事です.

こんにちは.
NTTドコモの5年目社員の石井です.

普段の業務ではPytorchをよく扱っているのですが,2020年はPytorch界隈のエコシステムにおける発展が著しかったと感じています.

この記事ではいくつかあるエコシステムのうちの1つである Pytorch Metric Learning を紹介しながら2020年を締めていこうと思います.

本記事で扱う内容

今年は皆様にとっても変化の大きい1年だったと思います.

私も自宅でYouTubeで好きなアーティストのMVを見る機会が増えて,様々な映像作品やイラストといったデザインに魅力を感じるようになりました.
そこで,今回はデザイン性の高い画像データを用いて Deep Metric Learning の手法によるカテゴリ分類に挑戦して行こうと思います.

本記事で扱う内容としては以下になります.

  • Pytorch Metric Learningの紹介
  • Pytorch Metric Learningを用いたカテゴリ分類の実践

Pytorch Metric Learningの紹介

まず初めにPytorch Metric Leaningについて紹介していきます.

Pytorch Metric Learning とは?

Pytorch Metric LearningとはDeep Metric Learningに必要な機能をコンポーネント化して9つのモジュールとして提供しているライブラリです.

pml_logo.png

既存のソースコードに対して個々の機能を独立で利用することもできますし,各種コンポーネントを組み合わせてDeep Metric Learingにおける学習から評価までのワークフロー全てを実現することもできます.
例えば,独自でDeep Metric Learningを実装しようとした場合には自作関数によるTripletLossの定義やペアやTripletと呼ばれるモデル学習時のデータの組み合わせ作成などを自前で実装する必要があります.

ここで,Pytorch Metric Learningを利用することでこれらに必要な処理をたった数行で実装できるようにしてくれます.しかも,論文にて効果を発揮している20個以上のLoss関数や10パターン以上の学習データの組み合わせ作成方法などを提供してくれるため使わない手はないと思います.

Pytorch Metric Learningで提供する9つのモジュールのそれぞれについては以下になります.

  • Distances
  • Losses
  • Miner
  • Reducers
  • Regularizes
  • Smaplers
  • Trainers
  • Testers
  • Utils

これらの個別の解説はここでは割愛しますが,公式のドキュメントに詳細に記載してありますので興味のある方はそちらを是非見てみてください.

Pytorch Metric Learningの使い方

次にPytorch Metric Learningの基本的な使い方について解説します.

前提としてPytorch Metric Learningの基本的な使い方を理解するためには,Deep Metric Learningの仕組みについて理解しておく必要があるのでそちらについて簡単に解説していきます.

Deep Metric LearningはMetric Learningの仕組みをDeep Neural Networkによって非線形変換を通じて再現しようとする方法です.このMetric Learningは距離学習と呼ばれるように,入力されるデータを同じラベルのデータは距離が近くなるように,逆に異なるラベルは距離が遠くなるように特徴ベクトル空間を出力するように重みを学習していくアルゴリズムです.そのため,Metric Learningは分類問題のように線形分離するような直接を求めるのではなく,Embeddingの処理となることに注意してください.

そして,このMetric LearningをDeep Metric Learningにて実現するためにTriplet Networkという仕組みを用います.Triplet Networkはanchorと呼ばれるサンプルとanchorと同じラベルのanchor-positive,異なるラベルのanchor-negativeの3つを1組としたデータを入力として,モデルの出力によって得られた特徴量ベクトルをLoss関数に入力してネットワークの重みを更新していきます.

この際に利用されるLoss関数にはいくつか種類があるのですが以下の式に示すようなTripletLossと呼ばれるLoss関数が一般的に利用されます1.この辺の詳しい解説はこちらの記事が分かりやすいと思うので参考にしてみてください.

L_{triplet} = [d_{ap} - d_{an} + m]_{+} \\

d_{ap}: d(\hspace{3pt} f(x_{a}),\hspace{3pt} f(x_{p} )\hspace{3pt})\\
d_{an}: d(\hspace{3pt} f(x_{a}),\hspace{3pt} f(x_{n} )\hspace{3pt})\\
m: margin

では,これまでの話を踏まえてPytorch Metric Learningで実装する方法を解説していきます.

mojikyo45_640.gif

主に上記に示すようなプロセスで処理を実装していきます.(利用するDistanceやLossによっては多少の違いはありますが基本的な流れは共通です)

  1. MinerによるLossを計算するためのペアやTripletの組み合わせを生成する
  2. Distanceによる入力されたEmbedding空間から組み合わせ毎の距離や類似度を計算する
  3. Lossにて組み合わせ毎のLoss値を計算する
  4. 必要に応じてRegularizerを適応してLoss関数に正則化を実施する
  5. Reduceによる組み合わせ毎の複数のLoss値を単一値に変換する

これらの流れをPytorch Metric Learningで実装すると以下のような形になります.
例としてTripletLoss(Pytorch-Metric-LearingではTripletMarginLossが対応)を利用した実装方法を以下に示します.

from pytorch_metric_learning.miners import TripletMarginMiner
from pytorch_metric_learning.distances import CosineSimilarity
from pytorch_metric_learning.losses import TripletMarginLoss
from pytorch_metric_learning.reducers import ThresholdReducer

distance = CosineSimilarity()
reducer = ThresholdReducer(low = 0)
losser = TripletMarginLoss(margin=0.2, distance=distance, reducer=reducer)
miner = TripletMarginMiner(margin=0.2, distance=distance)

# 学習ループ
for i, (data, labels) in enumerate(dataloader):
    optimizer.zero_grad()
    embeddings = model(data)
    triplets = miner(embeddings, labels)
    loss = losser(embeddings, labels, triplets)
    loss.backward()
    optimizer.step()

学習ループの中で取り出したバッチに対してモデルを通じて特徴量空間に写像した後に miner にて組み合わせを計算, losser によるコサイン類似度の計算,TripletLossの計算を実施して重みの更新を行っています.

このように簡単にDeep Metric Learningを実装することができるようになります.
さて,次からはDeep Metric Learningを用いたカテゴリ分類を画像データを用いて実践していきましょう.

Deep Metric Learningを用いたカテゴリ分類の実践

ここまででDeep Metric Learningの簡単な仕組みとPytorch Metric Learningでの実装方法は理解できていると思いますので,次におしゃれなデザイン性の高い画像データセットにてMetric Learningを活用したカテゴリ分類を実践していこうと思います.

取り組みの全体像

取り組みの全体像は以下の図のようなイメージとなります.

qiita_image_01.png

  1. デザイン性の高い画像と画像を識別するカテゴリ情報を正解データとしたデータセットを用いてDeep Metric Learningによる学習モデルを生成
  2. 生成された学習モデルに対して新規の画像を入力して得られた特徴ベクトルがどの空間にマッピングされるかを確認

データセット

利用するデータは Unsplash という海外の著作権フリーの写真素材サイトのデータとなります.こちらのサイトにはおしゃれな写真が多数掲載されており,サイト自体もシンプルで洗練されたデザインとなっていて,デザイン性の高い写真を探している人にはうってつけのサイトだと思います.
以下のようなクオリティの高い画像が手軽に利用可能です.
qiita_image_02.png

今回はこのサイトより提供されている画像データとトピックというカテゴリ情報を用いて,3つのカテゴリとそのカテゴリ毎に1000枚の画像を用意したデータセットとして利用していこうと思います.

用意したカテゴリ情報は以下になります.

ID カテゴリ名 説明 サンプル
0 people 人物を中心としたアート系の写真 https://unsplash.com/photos/RMxmvxz9tHQ
1 plant 植物や果物などの自然を中心とした写真 https://unsplash.com/photos/cfxnOUSLrgk
2 water 海などの水を中心とした写真 https://unsplash.com/photos/Sb7x-pgnsWI

モデル学習

今回の学習では以下のようなパラメータにて学習を行いました.

パラメータ 内容
損失関数 TripletLoss
距離算出方法 CosineSimilarity
最適化関数 Adam
学習率 1e-4
バッチ数 128
エポック数 50

Pytorch Metric Learningの学習に用いたコードは以下になります.

  • モデル定義
import torch
import torch.nn as nn
import torch.nn.function as F

class ArtworkModel(nn.Module):
    def __init__(self):
        super(ArtworkModel, self).__init__()        
        self.conv1 = nn.Conv2d(3, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.fc = nn.Sequential(
            nn.Linear(451584, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, 128),
        )

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 3)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

モデル定義は非常にシンプルなものとしています.
畳み込みを2層を重ねた後にプーリング層を適応し,出力されたパラメータを1次元に変換して全結合層で 128 個のパラメータを出力するように設計しています.

  • データローダー
import torchvision
from torchvision import transforms, datasets
from torch.utils.data import Dataset, DataLoader

# 読み込むデータのパス
train_data_path = './data/img/train'
test_data_path  = './data/img/test'

# 画像前処理の定義
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(degrees=30),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# データローダー作成
train_dataset = datasets.ImageFolder(train_data_path, transform=transform)
test_dataset = datasets.ImageFolder(test_data_path, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

学習用のデータは data/img/train のフォルダ,テスト用のデータは data/img/test に配置しているためご自身の環境に合わせて適切に設定してください.画像はtorchvisionを用いて然るべき前処理を適応しています.

  • 学習と評価
import os
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from pytorch_metric_learning.miners import TripletMarginMiner
from pytorch_metric_learning.distances import CosineSimilarity
from pytorch_metric_learning.losses import TripletMarginLoss
from pytorch_metric_learning.reducers import ThresholdReducer

# 学習用関数
def train(model, loss_func, mining_func, device, dataloader, optimizer, epoch):
    model.train() 
    for idx, (inputs, labels) in enumerate(dataloader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        embeddings = model(inputs)
        indices_tuple = mining_func(embeddings, labels)
        loss = loss_func(embeddings, labels, indices_tuple)
        loss.backward()
        optimizer.step()
        if idx % 10 == 0:
            print('Epoch {} Iteration {}: Loss = {}, Number of mined triplets = {}'.format(epoch, idx, loss, mining_func.num_triplets))
    print()

# テスト用関数
def test(model, dataloader, device, epoch):
    _predicted_metrics = []
    _true_labels = []
    model.eval()
    with torch.no_grad():    
        for i, (inputs,  labels) in enumerate(dataloader):
            inputs, labels = inputs.to(device), labels.to(device)
            metric = model(inputs).detach().cpu().numpy()
            metric = metric.reshape(metric.shape[0], metric.shape[1])
            _predicted_metrics.append(metric)
            _true_labels.append(labels.detach().cpu().numpy())
    return np.concatenate(_predicted_metrics), np.concatenate(_true_labels)

# パラメーター
epochs = 50
laerning_rate = 1e-4
batch_size = 128
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = ArtworkModel().to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
test_predicted_metrics = []
test_true_labels = []

# 学習と評価
for epoch in range(1, epochs + 1):
    print('Epoch {}/{}'.format(epoch, epochs))
    print('-' * 10)
    train(model, loss_func, mining_func, device, train_loader, optimizer, epoch)
    _tmp_metrics, _tmp_labels = test(model, test_loader, epoch)
    test_predicted_metrics.append(_tmp_metrics)
    test_true_labels.append(_tmp_labels)

カテゴリ分類の確認

さて,モデル学習は済んでおりますので続いてカテゴリが正確に分類されているのかを確認していきましょう.

Deep Metric Learningによって学習した今回のモデルは入力データを同じラベルは距離が近くなるように,逆に異なるラベルは距離が離れるように 128 次元の特徴量ベクトルに写像するようになっているはずです.この128次元の出力された特徴量ベクトルが正確に分けられているのかを t-SNE を用いて2次元に変換して確認します.

import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

tSNE_metrics = TSNE(n_components=2, random_state=0).fit_transform(test_predicted_metrics[-1])

plt.scatter(tSNE_metrics[:, 0], tSNE_metrics[:, 1], c=test_true_labels[-1])
plt.colorbar()
plt.savefig("output.png")
plt.show()

出力された結果は以下の通りです.
2次元平面に変換することでこのように3つのカテゴリがそれっぽく分類されていることがわかります.もっとはっきり分類されていることを期待しましたが,パラメーターチューニングをしていないのでネットワーク改良と適切なパラメータを設定すればより正確に分類ができるかと思います.これを見ると,ラベル0の人物を中心とした写真とラベル1の植物を中心した写真が混在しているのに対して,ラベル2の水を中心とした画像は独立しているように見受けられますね.

qiita_image_output_01.png

最後に新規画像がどのカテゴリに近いのかを試してみましょう.
今回は私が用意した画像はこちらになります.これはマルタ共和国という国のゴザ島の写真です.
この写真がちゃんと適切なカテゴリに分類されるのかを試してみます.

IMG-5470.jpg

この画像をモデルに入力して特徴量ベクトルに変換された出力はラベル2の水を中心とした画像に分類されていることが確認できました.

qiita_image_output_02.png

Deep Metric Learing 面白い!

さいごに

Pytorch Metric Learningを用いてデザイン性の高い画像に関するカテゴリ分類を実践してみました.

Metric Learningの手法を用いたカテゴリ分類を通じて,クラス分類タスクとは異なりモデルを介した特徴ベクトルがどのような空間上に写像されるのかを見ることで,他ラベルとの関係なども理解することができて非常に面白い手法だと感じました.

ちなみに上手くいかなかったのですが,別のタスクとして歌手のアートワーク(CDジャケットやdiscography)を用いて歌手をイメージするデザインがうまく分類できないかも試してみました.上手くいかなかった要因としては歌手のアートワークは歌のコンセプトによって大幅に変わることと,歌手とは別のデザイナーがジャケットなどを作成している場合もあるため一貫性がなく傾向を掴みきれなかったのかと思います.(今回は少量のデータセットで実験したことも原因として考えられますのでもっとデータを増やして試してみたいですね)

2020年はTorchServeやPytorch Forecastingなどと言ったPytorchに関するエコシステムが大幅に増えた印象が強く,日々新しいものに触れる機会があって知識欲を掻き立てられる一年でした.有効かつ使いやすいツールが増えることで業務でもアプローチの幅が広がりますので,今後もさらに知見を深めて活用していきたいと思います.

引き続きNTTドコモ R&D 控え室 アドベントカレンダーとメインのNTTドコモ R&D アドベントカレンダーにてドコモ社員による多様な記事が公開されますのでお楽しみください.

参考資料


  1. その他にもSiamese Networkや損失関数で距離を表現する方法もあります 

87
70
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
87
70