LoginSignup
2
9

【論文】自己教師あり学習「SimCLR」を理解、実装する

Last updated at Posted at 2022-07-23

はじめに

※このブログは東京大学大学院 情報理工学研究科の授業「映像メディア学」のレポートとして書かれたものです。

深層学習ではデータとラベルを用いてモデルを学習する、所謂データドリブンな手法が用いられています。しかし膨大なデータに対してラベルを付与する作業は非常にコストがかかるため、大規模データセットの作成は非常に困難であるのが現状です。そのため近年、ラベルを用いずにモデルの事前学習を行う自己教師あり学習、特に画像処理の分野では対照学習の研究に注目が集まっています。このブログでは対照学習手法の一つであるSimCLRの論文を読み、再現実装を通して理解を深めるものとなります。

論文読み

A Simple Framework for Contrastive Learning of Visual Representations [Ting Chen+ ICML20]
https://arxiv.org/abs/2002.05709

Introduction

人間の手を使わずに画像表現を学習することは長年の課題であり、既存のアプローチは基本的に生成的手法と識別的手法に分類することが出来る。

生成的手法はピクセルレベルの生成(VAEやGAN等)を通して表現の学習を行う。しかしこの手法は計算量が多く、表現学習には必要ない学習があるといった欠点がある。一方で識別的手法では教師あり学習で使用されるような目的関数を用いて表現の学習を行う。具体的にはラベルなしデータセットからデータとラベルを生成しpretext-taskを解くことで表現を学習する。多くの識別的手法ではヒューリスティックなpretext-taskに依存しており(ジグソーパズル、カラー化等)、これは表現の一般性を制限する可能性がある。識別的手法の中では、対照学習におけるアプローチは近年最先端の結果を達成しており注目を集めている。本研究では画像表現の対照学習のための簡単なフレームワークSimCLRを提案する。SimCLRは先行研究の性能を凌駕するだけでなく、特殊なアーキテクチャやメモリバンクを使用しない単純な構造となっている。

本研究では以下のことを示している。

  1. 表現を学習する上で、データ拡張の構成が重要である。対照学習は教師あり学習よりもデータ拡張の恩恵を強く受ける。
  2. 表現と対照損失の間に非線形変換を導入すると、得られる表現の質が向上する。
  3. 対照損失による表現学習は埋め込み正規化と温度パラメータの恩恵を受ける。
  4. 対照学習は教師あり学習と比較して、バッチサイズの大きさと学習ステップの多さによる恩恵が大きい。

これらの知見を組み合わせることで、ImageNetにおける自己教師あり学習において従来の手法を大幅に上回り、教師あり学習に匹敵する性能を獲得した。

Method

SimCLRは潜在空間における対照損失を用いて、同じデータにおける異なる拡張を施したビュー間の類似度を最大化するように表現を学習する。具体的には以下の構成要素からなる。

  • データ拡張
    与えられたデータをランダムに変形して、2つのビュー $\tilde{x_i}, \tilde{x_j}$を生成し、これらを正のペアとみなす。本研究では3つの単純な拡張を順次適用する。
    ・ランダムクロップ、元のサイズへのリサイズ
    ・ランダムな色歪み
    ・ランダムなガウスぼかし
    特にランダムクロップと色歪みの組み合わせは非常に重要である。
  • エンコーダ
    データから表現を抽出するニューラルネットワークベースのエンコーダを$f$とする。本研究では一般的に使用されるResNetを使用し、表現$h_i, h_j$を獲得する。
  • 投影ヘッド
    対照損失が適用される空間に表現を写像するための投影ヘッド$g$を導入する。具体的には一層の隠れ層を持つMLPを用いて$z_i = g(h_i) =W^{(2)}σ(W^{(1)}h_i)$を求める。ただし活性化関数$σ$はReLUを用いる。すなわち対照損失は $h_i,h_j\ $ 間ではなく $z_i,z_j\ $ 間で求める。
  • 対照損失
    バッチサイズNのミニバッチをランダムにサンプリングし、上記のデータ拡張によって2Nのデータを得る。正例 $x_i,x_j\ $ が与えられた時、バッチ内の他の2(N-1)個のデータは負例として扱う。損失関数は以下のように定義される(NT-Xent, "Normalized Temperature-scaled Cross Entropy")
l(i,j) = -log\frac{exp(sim(z_i,z_j)/τ)}{Σ^{2N}_{k=1} \boldsymbol{1_{k\neq i}} exp(sim(z_i,z_k)/τ)}
L = \frac{1}{2N} Σ^{N}_{k=1}[l(2k-1,2k)+l(2k,2k-1)]

学習後エンコーダネットワーク$f$のみを他タスクに使用する。本論文では教師なし事前学習は基本的にImageNetを使用している(CIFAR-10の実験は論文内Appendix B.9参照)。事前学習モデルを様々なデータセットで転移学習し評価を行う。またlinear evaluation protocolと呼ばれる手法を用いた評価も用いる。これはベースとなるネットワークを凍結した状態で線形分類器のみを学習し、得られた精度を表現の質の指標とするというものである。

Data Augmentation for Contrastive Representation Learning

この章ではデータ拡張手法による精度への影響の解析を行っている。データ拡張手法にはクロップ+リサイズ(+フリップ)、回転、カットアウトなど空間的変換や、色歪み、ガウスぼかし、ソーベルフィルタのような外観変換がある。※既存の対照学習の研究ではグローバルからローカルな部分の予測や近傍ビューの予測といったアプローチを使用していた。本論文ではランダムトリミング+リサイズが上記のアプローチを包含していると考えている。

一つの変換、または二つの変換の合成を用いて線形評価を行った結果、変換の合成によって表現の質が向上した。またランダムクロップと色歪みの組み合わせが非常に優秀であるという結果となった。ランダムクロップのみを用いる場合の重大な問題として画像から得られるパッチはほぼ同様の色分布となることである。よってランダムクロップのみを用いると画像の識別にカラーヒストグラム情報のみを使用する可能性がある。このことからランダムクロップと色歪みの組み合わせが重要であると考えられる。

また色歪みの強さを上げると自己教師あり学習におけるlinear evaluationが大幅に改善された。同様の拡張を教師ありモデルに対して適用しても精度は変化しない、または低下することが観測された。このことから対照学習は教師あり学習よりも強く(色の)データ拡張の恩恵を得ていることが読み取れる。

Architectures for Encoder and Head

基本的にモデルの深さと幅を大きくすればするほど精度は上昇する。これは教師あり学習でも同様の結果となっているが、教師ありモデルと教師なしモデル+線形分類器のギャップはモデルサイズが大きくなることに縮小している。このことから教師なし学習は教師あり学習と比較してモデルサイズの恩恵を多く受けることが示唆されている。

また投影ヘッドの重要性を調べるために以下の3つのアーキテクチャを使用した線形評価を行う。(1)identity mapping (つまり何もしない) (2) linear projection (3) ReLUの活性化関数を用いたnonlinear projection。結果として、nonlinear projectionはlinear projectionやidentity mappingより良い成績となった。

Loss Functions and Batch Size

NT-Xent損失とその他の損失(ロジスティック損失、マージン損失)を比較し、NT-Xentが優れていることが示された。損失におけるl2正規化と温度$τ$の重要性の検証も行っている。

またエポック数とバッチサイズの検証も行っている。エポック数が少ない場合(100程度)、バッチサイズは大きい方が明確に精度が良くなるが、エポック数が増加するにつれてこのギャップは減少していく。教師あり学習と異なり、対照学習ではバッチサイズが大きくなると、多くの負例が使用できるため収束が容易になり、また長く訓練を行う事でも同様に多くの負例を用いることが出来るため結果が良くなるのだと考えられる。

Comparison with State-of-the-art

ResNet50を異なる幅(*1,*2,*4)を用意し、1000エポックで学習を行った。

  • linear evaluation
    線形分類器のみを学習して評価。simCLRとResNet-50(隠れ層の幅4倍)の結果は他の自己教師あり学習よりも高い精度を達成し、教師あり事前学習済みResNet-50モデルと同等の精度となっている。
  • Semi-supervised learning
    ILSVRC-12学習データセットの1%-10%のラベル付きデータをサンプリングし、(scratch or ImageNetで自己教師あり学習した)ネットワークをfine-tuneする。他の半教師、自己教師あり学習の性能を上回っている(ただし幅4倍のモデルを使用しているため、単純な比較はできないと思われる)。

  • Transfer learning
    resnet-50*4を用いて他データセットを用いた転移学習性能を評価する。自己教師ありモデルは教師あり学習の精度とほぼ同等の性能となった。(appendixにおいてresnet50を用いた場合の実験が行われており、この場合教師あり学習が明確に優位であった点について触れられている。)

appendixにはデータ拡張の詳細、バッチサイズとエポック、モデルの幅の調査等が記載されている。

再現実装

以下CIFAR10を用いたSimCLRの事前学習を行う。具体的には論文内にもCIFAR10を用いた実験(B.9. CIFAR-10)が行われているので、その再現実装を行う。

データ拡張の詳細

入力画像
download.png

ランダムクロップ+リサイズ+ランダムフリップ

random_crop = nn.Sequential(transforms.RandomResizedCrop(32, scale=(0.08, 1.0), ratio=(3/4, 4/3)),
  transforms.RandomHorizontalFlip(p=0.5))
img = random_crop(img)

download.png

色歪み

def get_color_distortion(s=1.0):
    # s is the strength of color distortion.
    color_jitter = transforms.ColorJitter(0.8*s, 0.8*s, 0.8*s, 0.2*s) 
    rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8)
    rnd_gray = transforms.RandomGrayscale(p=0.2)

    color_distort = transforms.Compose([rnd_color_jitter, rnd_gray])
    return color_distort
transform = get_color_distortion(s = 0.5)
img = transform(img)

download.png

実装

code

事前学習

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torchvision import datasets, models, transforms
import torch.nn.functional as F

N = 256 #バッチサイズ
tau = 0.1
lr = 0.5
num_epochs = 100

#データセット
trainset = datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transforms.ToTensor())
trainloader = torch.utils.data.DataLoader(trainset, batch_size=N,
                                          shuffle=True, num_workers=2)

#データ拡張
def get_color_distortion(s=1.0):
    # s is the strength of color distortion.
    color_jitter = transforms.ColorJitter(0.8*s, 0.8*s, 0.8*s, 0.2*s) 
    rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8)
    rnd_gray = transforms.RandomGrayscale(p=0.2)

    color_distort = transforms.Compose([rnd_color_jitter, rnd_gray])
    return color_distort

random_crop = nn.Sequential(transforms.RandomResizedCrop(32, scale=(0.08, 1.0), ratio=(3/4, 4/3)),
  transforms.RandomHorizontalFlip(p=0.5))

#モデル
model = models.resnet18()

model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(3, 3), bias=False)
model.maxpool = nn.Identity()
model.fc = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 128))

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(),lr=lr)

#学習
def calculate_loss(zi, zj):
  # 参考
  # https://theaisummer.com/simclr/
  batch_size = zi.shape[0]
  zi_norm = F.normalize(zi, dim=1)
  zj_norm = F.normalize(zj, dim=1)
  representation = torch.cat([zi_norm,zj_norm], dim=0) #[2*batch_size, h_dim]
  sim = torch.matmul(representation, torch.t(representation))

  sim_ij = torch.diag(sim, batch_size)
  sim_ji = torch.diag(sim, -batch_size)
  top = torch.exp(torch.cat([sim_ij,sim_ji],dim=0)/tau)
  mask = torch.ones((2*batch_size,2*batch_size), dtype=bool)
  for i in range(batch_size):
    mask[i,i] = False
  mask = mask.to(device)
  bot = mask * torch.exp(sim / tau) #要素積
  all_losses = - torch.log(top/torch.sum(bot, dim = 1))
  loss = torch.sum(all_losses)/ (2*batch_size)
  return loss

model.train()
loss_sum = 0
for epoch in range(num_epochs):
  loss_sum = 0
  for x, _ in trainloader:
    x = x.to(device)
    optimizer.zero_grad()

    color_distortion = get_color_distortion(s = 0.5)
    x1 = random_crop(color_distortion(x))
    x2 = random_crop(color_distortion(x))
    z1 = model(x1)
    z2 = model(x2)
    loss = calculate_loss(z1, z2)
    loss_sum += loss
    loss.backward()
    optimizer.step()
  print("epoch",epoch,"loss",(loss_sum/len(trainloader)).item())
torch.save(model.state_dict(), "./model.pth")

ラベルを用いて学習

import torch
import torch.nn as nn
from torchvision import datasets, models, transforms

N = 256 #バッチサイズ
lr = 0.1
num_epochs = 20

#データセット
trainset = datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transforms.ToTensor())
testset = datasets.CIFAR10(root='./data', train=False,
                                        download=True, transform=transforms.ToTensor())
trainloader = torch.utils.data.DataLoader(trainset, batch_size=N,
                                          shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=N,
                                         shuffle=False, num_workers=2)

#モデル
model = models.resnet18()
dim = model.fc.in_features

model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(3, 3), bias=False)
model.maxpool = nn.Identity()
model.fc = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 128))
model.load_state_dict(torch.load("./model.pth"))

for param in model.parameters(): #パラメータを固定
    param.requires_grad = False

model.fc = nn.Linear(dim, 10)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
model = model.to(device)

#Loss
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(),lr=lr)

#学習

loss_sum = 0
for epoch in range(num_epochs):
  model.train()
  loss_sum = 0
  print("epoch",epoch)
  for x, labels in trainloader:
    x = x.to(device)
    labels = labels.to(device)

    optimizer.zero_grad()
    outputs = model(x)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

  model.eval()
  if (epoch+1) %5 == 0:
    correct = 0
    for x, labels in testloader:
      x = x.to(device)
      labels = labels.to(device)
      outputs = model(x)
      _, preds = torch.max(outputs, 1)
      correct += torch.sum(preds == labels)
    print("accuracy",(correct/len(testset)).item())

実験では以下のような設定を用いた。
事前学習 100epoch
本学習 20epoch、
optimizer Adam
結果としては以下の表のようになった。この結果から実際に自己教師あり学習によって有効な表現を得ることが出来ることが確認できた。

特徴抽出器 fine-tune
ランダムに初期化したモデル 25.1[%] 74.3[%]
自己教師あり学習済みモデル 63.8[%] 76.6[%]

損失の実装において自分が作成したコードでは学習が非常に遅かったため、以下のサイトを参考して実装した。https://theaisummer.com/simclr/
自身で作成したコードは以下の通りである。

code
def calculate_loss(zis, zjs):
  N = zi.shape[0]
  loss = 0
  z_norm = F.normalize(torch.cat((zis,zjs),dim=0), dim=1)
  for i in range(N):
    top = torch.exp(torch.dot(z_norm[i], z_norm[i+N])/tau)
    bot1 = -torch.exp(torch.dot(z_norm[i], z_norm[i])/tau)
    bot2 = -torch.exp(torch.dot(z_norm[i+N], z_norm[i+N])/tau)
    for j in range(2*N):
      bot1 += torch.exp(torch.dot(z_norm[i], z_norm[j])/tau)
      bot2 += torch.exp(torch.dot(z_norm[i+N], z_norm[j])/tau)
    loss -= (torch.log(top/bot1) + torch.log(top/bot2))
    loss /= 2*N
  return loss

終わりに

SimCLRの論文読みおよび再現実装を行い、実際にSimCLRの性能を確認することが出来た。また実装では並列計算を意識したコーディングを行わないと計算スピードが落ちることを実感した。この調子で自己教師あり学習の研究を追っていきたい。

2
9
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
9