SimSiam (Simple Siamese) とは類似度に基づいた自己教師あり学習の1種であり、2020年にFacebook AI Research (FAIR) によって提案された新しいアプローチです。 [1]
また、2021年にはCVPR2021において Best Paper Honorable Mention を獲得しており、かなり注目を浴びた論文と言えるでしょう。[2]
本記事ではPyTorchを使用して、SimSiam の簡単な実装および CIFAR-10 による原理検証を行います。
背景知識
SimSiam の立ち位置やその利点について理解するためには、以下のような知識を必要とします。まずはこれらの背景知識について説明します。
- 自己教師あり学習
- Siamese Network
- Collapsing Solutions
1. 自己教師あり学習
自己教師あり学習とは、ラベルの付与されていないデータを利用して、データ間の特徴や表現を学習する手法です。紹介する SimSiam も自己教師あり学習に該当 します。
自己教師あり学習は、主に何らかのタスクで利用するための事前学習手法として扱われています。ラベルのないデータから学習したのち (事前学習)、その学習したネットワークを別の教師あり学習タスクに利用する (転移学習) という流れです。
自己教師あり学習の強みは何といっても人間がラベル付けしたデータを必要としない点であり、低品質のデータであっても学習が可能になるためコストを抑えることができます。
さらに近年では教師あり学習にも迫る性能を見せており、注目が集まっています。
2. Siamese Network
図1:Siamese Network の例
Siamese Network (シャムネットワーク) はニューラルネットワークの1種であり、2つの画像入力を全く同じサブネットワーク (構造だけでなく重みも同じ) を通すことで得られる特徴ベクトルから、それらの特徴距離を計算する手法です。
つまり入力データを特徴ベクトルに変換して比較することにより、何らかの形でデータの「類似度」を学習しているため、画像分類タスクなどに利用されます。
また、入力の類似度を学習するという点から Siamese Network はラベルを必要としない自己教師あり学習との相性が良く、Siamese Network に基づいた自己教師あり学習用の様々なアーキテクチャが考案されています。
しかしながら Siamese Network を利用することによる問題も存在し、それが後述する Collapsing Solution
になります。
この問題を回避しつつ学習する手法として、様々なネットワーク構造やアルゴリズムが提案されており、SimSiam もその1つというわけです。
3. Collapsing Solution
図2:Collapsing Solution の例
Collapsing Solution とは、Siamese Network による特徴ベクトルの学習時に起こる典型的な問題のことで、具体的には 任意の画像入力に対して定数を返すようにネットワークが学習されてしまうこと を指します。
Collapsing Solution の例として示した図2では、入力として似た画像ペア $x_1, x_2$ (ポジティブペア) のみを用意し、それらの特徴ベクトル $p_1, p_2$ を近づけるように学習しています。このような設定では、$p_1, p_2$ を定数にしてしまえば特徴ベクトルがなす角度は常に0となるため、損失関数を容易に小さく (大きく) できてしまいます。
もちろんそのように学習が進んだとしても分類タスクなどには一切活かせないため、この自明の解の方向に学習が進まないように工夫する必要があります。
先行研究での Collapsing Solution 回避法
先行研究では主に以下の3種類の方針でこの問題を回避しています。
- ネガティブペアの用意
-
オンラインクラスタリング (ネットワーク学習時にクラスタリングを考慮)
- SwAV [5] など
- モーメンタムエンコーダの導入
一方で、これらの手法はどうしてもネットワーク構造が複雑になったり、バッチサイズが大きくなってしまうなど、別の問題を引き起こしてしまいます。
シンプルなネットワーク構造のまま Collapsing Solution
を回避できることを実証した画期的な手法が SimSiam なのです。
SimSiam とは
SimSiam の構造

図3:SimSiamの構造。画像は論文[1]からの引用です。
1枚の訓練画像 $x$ に異なる変形を施してできた画像 $x_1, x_2$ をそれぞれ同じネットワーク encoder $f$ に入力し特徴ベクトルへと変換します。特徴ベクトル $f(x_1)$ はその後 predictor $h$ を通り、最終的に $h(f(x_1))$ と $f(x_2)$ のコサイン類似度を計算します。
すなわち特徴空間において、元々は同じ画像であった$x_1$と$x_2$が近くなり、異なる画像の場合は遠くなるように学習しているわけです。
入力として対象画像 (ポジティブペア) のみから2つの画像を生成し活用するため、MoCoなどの対照学習とは異なり、対象と無関係の画像 (ネガティブペア) は用意しません。
SimSiam の利点
SimSiam の構造においても述べた通り、SimSiam ではネガティブペアを用意する必要がなく同一の画像に違う変形を施すことで類似度を計算するため、バッチサイズを抑えることができます。また、モーメンタムエンコーダなどの機構も必要としません。
勾配停止と推定器を導入するだけで、シンプルなアーキテクチャかつ比較的小さなモデルにもかかわらず Collapsing Solution
を回避できるため、効果的な学習が可能です。
このシンプルさはもちろんのこと、画像分類タスクにおいて高い精度を達成したという点もこの論文が大きな注目を浴びた要因の一つでしょう。
実際、下図のように論文 [1] において先行研究との比較がなされています。
図4:ImageNet 線形分類における SimSiam とその他の手法の比較。画像は論文[1]からの引用です。
これは ImageNet の線形分類の結果であり、224 $\times$ 224 の2つの画像を入力として ResNet-50 を各手法に基づいて事前学習させてから線形分類を行っています。
この結果から、SimSiam は非常にシンプルな構造ながら他のネットワークにも比類する (時には上回る) 性能であることが分かります。
SimSiam の実装
それでは実装に移っていきましょう。
実装が比較的簡単である点も SimSiam の利点の1つです。
前述の通り、SimSiam は Encoder
(Backbone + Projector) と Predictor
からなります。
Backboneは特徴ベクトルを出力するように学習させたいネットワークを示しており、SimSiam においては基本的に最終層を取り除いた ResNet が利用されます。
そのため 実装する必要があるのは Projector, Predictor, そして SimSiam 自体 となります。
Projector
class Projector(nn.Module):
def __init__(self, in_dim, out_dim=2048):
super(Projector, self).__init__()
self.layers = nn.Sequential(
nn.Linear(in_dim, in_dim, bias=False),
nn.BatchNorm1d(in_dim),
nn.ReLU(inplace=True),
nn.Linear(in_dim, in_dim, bias=False),
nn.BatchNorm1d(in_dim),
nn.ReLU(inplace=True),
nn.Linear(in_dim, out_dim, bias=False),
nn.BatchNorm1d(out_dim),
)
def forward(self, x):
return self.layers(x)
Projector
は 3層からなる多層パーセプトロンで、最終層だけ ReLU が入らない構成にしています。
Predictor
# 予測器の設定
class Predictor(nn.Module):
def __init__(self, in_dim=2048, pred_dim=512, out_dim=2048):
super(Predictor, self).__init__()
self.layers = nn.Sequential(
nn.Linear(in_dim, pred_dim, bias=False),
nn.BatchNorm1d(pred_dim),
nn.ReLU(inplace=True),
nn.Linear(pred_dim, out_dim)
)
def forward(self, x):
return self.layers(x)
Predicctor
は こちらも Projector
と同様に多層パーセプトロンで、最終層には BatchNorm1d 及び ReLU が入らない構成にしています。
BatchNorm1d を含めるべきかどうかなど、Projector
や Predicctor
の構成については論文中でも議論がなされていますが、ここでは論文で実際に使用された設定に従っています。
SimSiam
class SimSiam(nn.Module):
def __init__(self, backbone, projector, predictor):
super(SimSiam, self).__init__()
self.backbone = backbone
self.projector = projector
self.predictor = predictor
def forward(self, x1, x2):
# x1,x2は変形後の画像を表す
z1 = self.projector(self.backbone(x1).flatten(start_dim=1))
z2 = self.projector(self.backbone(x2).flatten(start_dim=1))
p1 = self.predictor(z1)
p2 = self.predictor(z2)
# .detach()で勾配停止させる
return p1, p2, z1.detach(), z2.detach()
SimSiam において注目すべきなのは、 forward
内の .detach()
です。
この処理により encoder からの直接出力 z1, z2 に対する勾配計算が停止するため、論文中の画像に書かれていた stop-grad
が実現します。
SimSiam の検証
実装したSimSiamの原理を検証するため、画像データとしてCIFAR-10を利用して訓練を行いました。
訓練データの用意
まずはCIFAR-10に異なる transform
を施して、各画像について2枚組の訓練データを用意する必要があります。
transform
は論文中の表記に基づいて、以下のように PyTorch の関数を利用して実装しています。なお、ガウシアンブラーは使用していません。
データ変形用 transform
# データセット変形用
transform = transforms.Compose([
transforms.RandomResizedCrop(32, scale=(0.2, 1.)),
transforms.RandomApply([
transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
], p=0.8),
transforms.RandomGrayscale(p=0.2),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
SimSiam用のデータセット作成
class SimSiamDataset(Dataset):
def __init__(self, root, transform1, transform2, train=True):
self.dataset = datasets.CIFAR10(root=root, train=train, download=True)
self.transform1 = transform1
self.transform2 = transform2
def __len__(self):
return len(self.dataset)
def __getitem__(self, idx):
image, label = self.dataset[idx]
image1 = self.transform1(image)
image2 = self.transform2(image)
return image1, image2, label
同じ画像から2つの変形画像を作る必要があるため、SimSiam 用のデータセットを定義しています。
data_loaderのインスタンス化
train_dataset = SimSiamDataset(
root='./data',
transform1=transform,
transform2=transform,
train=True
)
train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True, num_workers=2)
バッチサイズは論文 [1] 中の値を利用しています。
encoder の性能評価
100エポック学習させる過程で、encoderの性能を評価するため、10エポックごとに k近傍法 (kNN) を用いてラベル付けを行いました。k近傍法の細かい説明については他の記事を参照ください。
具体的には、以下のような流れです。ここでは k = 200 としています。
- 訓練に用いたデータと validation 用に作成したデータを backbone である ResNet-18 に入力し、それぞれ特徴ベクトルを得る
- 各 validation データの特徴ベクトルに対して、訓練データの特徴ベクトルとのコサイン類似度を計算し、類似度が-1に近い順に k 個の訓練データを抜き出す
- k 個の訓練データが属するラベルのうち、頻度が最も高いものを validation データのラベルと推定する
特徴ベクトルの計算 ・ cos類似度による kNN
import sklearn.metrics.pairwise as F
# 特徴ベクトルの計算
def feature_for_knn(model, data_loader, culculate_type):
# simsiamモデルを評価用に設定
model.eval()
features = []
labels = []
with torch.no_grad():
if culculate_type == "train":
for x, _, y in data_loader:
x = x.to(device)
feature = model.backbone(x).flatten(start_dim=1)
features.append(feature.cpu().numpy())
labels.append(y.cpu().numpy())
elif culculate_type == "val":
for x, y in data_loader:
x = x.to(device)
feature = model.backbone(x).flatten(start_dim=1)
features.append(feature.cpu().numpy())
labels.append(y.cpu().numpy())
features = np.concatenate(features, axis=0)
labels = np.concatenate(labels, axis=0)
return features, labels
def knn_cosine(train_features, train_labels, val_features, k):
# コサイン類似度の計算
cosine_sim = -F.cosine_similarity(val_features, train_features)
# valのラベルリスト
val_label_pred = []
for i in range(val_features.shape[0]):
# 上位k個の類似度が高い(-1に近い)インデックスを取得
top_k_indices = np.argsort(cosine_sim[i])[:k]
# 上位k個のラベルを取得
top_k_labels = train_labels[top_k_indices]
# 最頻出ラベルを取得
most_common_label = Counter(top_k_labels).most_common(1)[0][0]
val_label_pred.append(most_common_label)
val_label_pred = np.array(val_label_pred)
return val_label_pred
実行結果
特徴ベクトル計算用のモデルは backbone
、すなわち最終層を取り除いた ResNet-18 としています。
損失関数は p1 と z2 、p2 と z1 の負のコサイン類似度の平均としています。つまり、-1 に近づくほどうまく学習されているということです。
また各 epoch における loss の値に関しては、全訓練データのコサイン類似度の平均値としました。
図5:SimSiam のCIFAR-10による訓練時のグラフ。左縦軸 (青色) は各 epoch における loss を表し、右縦軸 (赤色) は 1 epoch ごとの kNN 分類精度を示しています。
epoch が進むごとに kNN 分類の精度が上がっている様子を赤色のグラフから確認できます。一方で loss に関しては一度大きく下がってからほとんど変動せず少し振動しており、学習率が高すぎた可能性があります。
学習率などのハイパーパラメータは論文 [1] の値をそのまま利用しましたが、loss に合わせて多少のチューニングが必要だったかもしれません。
まとめ
今回は PyTorch を用いて 自己教師あり学習の SimSiam を実装しました。
勾配停止させることで collapsing solution を回避して、効果的に学習させるアプローチは非常に面白い考えだなと思います。
実装したプログラムは以下のリンクから見ることができます。
以上です。