20
15

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

SimSiamで自己教師あり学習(Self-Supervised Learning)に挑戦

Last updated at Posted at 2021-07-30

#はじめに

自己教師あり学習(Self-Supervised Learning)の一つであるSimSiamを実装し、CIFAR10の学習で評価を試みる。

#SimSiamとは

論文はこちら。

Facebook AI Researchからの論文で、有名なKaiming He氏が共同執筆者となっている。
SimSiamはSimple Siameseの略。Siamはシャム猫のシャムでタイ王国のことだが、Siameseはシャム双生児(現在は結合双生児と呼ぶことが多いはず)のことなのだろう。

日本語の解説はWeb上にいくつかあるが、こちらのスライドが簡潔で分かりやすいと思う。

訓練画像1枚からランダムで加工した2枚の画像を生成して(ここはデータ拡張と同じ)、表現の類似度を損失として学習させる。「元が同じである2枚の加工済み画像に対して、同じような表現を出力できるモデルは、何らかの特徴を学習したのだろう」ということなんだと思われる。全ての画像に一律に同じ表現を出力しても類似度は一致してしまうが当然そんなモデルに意味はない。このような状態を"Collapsing Solutions"と呼び、これを防ぐのが工夫のしどころのようだ。
SimSiamは既存の技術に比べて、以下のようなメリットがあるようだ。

  • 非類似度の学習(Negative Sample)が不要
  • バッチサイズが大きくなくても良い
  • 実装が比較的簡単

論文にある図を添付して構造を示す。

スクリーンショット 2021-07-22 19.48.00.png

上の図の用語の解説としては、大体こんな感じで良いだろう。

用語 意味
Backbone 表現学習(Representation Learning)させたいモデル
Projector Backboneの出力を変換させる層
Predictor 対となる画像からのProjectorの出力を予想する層
Encoder Backbone+Projector

Predictorでペア画像に対するEncoderの出力を予測することになるが、学習が進めば平均的な出力を予想することになり、結果Encoderの出力も平均的な出力に近づいているので、結果的にBackboneも平均的(一般的)な特徴を学習したことになる。というような流れだと筆者は理解している。
図中のstop-gradは勾配計算を止めることで、これで"Collapsing Solutions"を防ぐらしい。
素人考えではProjectorの層を無くしてBackboneとPredictorを直結した方が早いんじゃないかと思うのだが、Projectorを入れてLoss計算用の空間に一旦投影する。この辺は先行研究であるSimCLRの論文で議論されているようだ。

#実装

公式実装がありPyTorch派の方はこちらを使えばそれで終わりのように思うが、筆者はtf.keras派なので自前で実装しなければならない。
Kerasの実装は既にいくつかあるようで、Kerasの公式のサイトにも実はCIFAR10を学習させるコード例があるのだが、筆者の見解ではこのコードには問題があり正しく学習していないように思われる。一方、TensorFlow向けのこちらのコードはうまく学習しているように思うので、公式実装と後者のコードを主に参考にする。

##コード

###モデル

tf.kerasの実装としてはカスタム訓練ループが必要なので、SimSiamという名前のモデルとして定義する。

from tensorflow.keras import layers
from tensorflow.keras import regularizers


class SimSiam(tf.keras.Model):
    def __init__(self, backbone, project_dim = 2048, hidden_dim=512, l2_reg=1e-4):
        super(SimSiam, self).__init__()
        self.backbone=backbone
        self.model, self.projector, self.predictor = self.make_model(backbone, project_dim, hidden_dim, l2_reg)
        self.loss_tracker = tf.keras.metrics.Mean(name="loss")

    def make_projector(self,prev_dim, num_layers, project_dim, l2_reg):
        dense_params = {'use_bias':False,
                        'kernel_initializer': tf.keras.initializers.VarianceScaling(scale=1.0/3, mode='fan_in', distribution='uniform'),
                        'kernel_regularizer':regularizers.l2(l2_reg)}
        inputs = layers.Input((prev_dim,))
        x = inputs
        if num_layers!=0:
            for _ in range(num_layers-1):
                x = layers.Dense(prev_dim, **dense_params)(x)
                x = layers.BatchNormalization(momentum=0.9, epsilon=1e-5)(x)
                x = layers.ReLU()(x)
            x = layers.Dense(project_dim, **dense_params)(x)
            x = layers.BatchNormalization(momentum=0.9, epsilon=1e-5)(x)
        return tf.keras.Model(inputs, x, name="Projector")
    
    def make_predictor(self,project_dim, hidden_dim, l2_reg):
        dense_params = {
                        'kernel_initializer': tf.keras.initializers.VarianceScaling(scale=1.0/3, mode='fan_in', distribution='uniform'),
                        'bias_initializer': tf.keras.initializers.VarianceScaling(scale=1.0/3, mode='fan_in', distribution='uniform'),
                        'kernel_regularizer':regularizers.l2(l2_reg)}
        inputs = layers.Input((project_dim,))
        x = layers.Dense(hidden_dim, use_bias=False, **dense_params)(inputs)
        x = layers.BatchNormalization(momentum=0.9, epsilon=1e-5)(x)
        x = layers.ReLU()(x)
        x = layers.Dense(project_dim, **dense_params)(x)
        return tf.keras.Model(inputs, x, name="Predictor")
            
    def make_model(self, backbone, project_dim, hidden_dim, l2_reg):
        inputs = layers.Input((None, None, 3))
        outputs = []

        x = backbone(inputs)

        prev_dim = x.shape[-1]
        projector = self.make_projector(prev_dim, 2, project_dim,l2_reg)
        x = projector(x)
        outputs.append(x)

        predictor = self.make_predictor(project_dim, hidden_dim, l2_reg)
        x = predictor(x)
        outputs.append(x)

        total_model = tf.keras.Model(inputs, outputs, name="SimSiam")
        return total_model, projector, predictor

    def get_encoder(self):
        encoder = tf.keras.models.Sequential([self.backbone,self.projector], name="Encoder")
        return encoder

    @property
    def metrics(self):
        return [self.loss_tracker]

    def train_step(self, data):
        with tf.GradientTape() as tape:
            z1, p1 = self.model(data[0],training=True)
            z2, p2 = self.model(data[1],training=True)
            loss1 = tf.keras.losses.cosine_similarity(p1, tf.stop_gradient(z2)) 
            loss2 = tf.keras.losses.cosine_similarity(p2, tf.stop_gradient(z1))
            loss_simsiam = tf.reduce_mean((loss1+loss2)/2)

            loss_decay = sum(self.model.losses)
            loss = loss_simsiam + loss_decay

        gradients = tape.gradient(loss, self.model.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))

        self.loss_tracker.update_state(loss)
        output_std = tf.reduce_mean(tf.math.reduce_std(tf.math.l2_normalize(tf.concat((p1,p2), axis=0), axis=-1), axis=0))
        results = {
            'loss': self.loss_tracker.result(), 
            'loss_simsiam': loss_simsiam,
            'output_std': output_std}
        return results

backboneとして何らかのモデルを与えると、それにProjectorとPredictorを連結したモデルを内部的に作成する。このモデルは出力として、ProjectorとPredictorの出力を同時に持つ。

train_step()の中身がSimSiamの実装の肝に当たる。tf.stop_gradientで勾配停止させて、tf.keras.losses.cosine_similarityを使って類似度を計算する。この計算は2枚それぞれについて実行できて、損失としてはその平均とする。
この処理で出てくるloss_simsiamが純粋に類似度での損失。tf.keras独特の処理で、損失はl2_regularizerとしてモデルに設定したWeightDecayを足さなければいけないようで、最終的なLoss表示はこれと一致しない場合がある。本記事の実装では、l2_regularizerは使わずOptimizer内で処理することにしてあるので、ほぼ一致する。
"output_std"はPredictorの出力の標準偏差で、ここは論文によると序盤は$1/\sqrt{dim}$付近が正常のようで、デフォルト値のdim=2048では0.022あたり。学習の進捗はLossの低下でわかるのだが、SimSiamでは意味のない学習でも最低値の-1.0に低下してしまうため、このような指標が必要になっていると思われる。学習に失敗すると、この数値は一気にほぼ0になる。

以上を見ると、実装はそんなに簡単ではないのでは?という疑問も生まれるが、他の手法よりは簡単ということなのだろう。実際SimCLRよりはかなり単純になっている。

###データセット

ペアのデータを作成するコードのみ抜粋しておく。

        ds_train_1 = tf.data.Dataset.from_tensor_slices(x_train)
        ds_train_2 = tf.data.Dataset.from_tensor_slices(x_train)
        ds_train_1 = ds_train_1.shuffle(train_len, seed=1).batch(batch_size,drop_remainder=True)
        ds_train_2 = ds_train_2.shuffle(train_len, seed=1).batch(batch_size,drop_remainder=True)
        ds_train_1 = ds_train_1.map(lambda image: data_augmentation(image, None,True), num_parallel_calls=tf.data.AUTOTUNE)
        ds_train_2 = ds_train_2.map(lambda image: data_augmentation(image, None,True), num_parallel_calls=tf.data.AUTOTUNE)
        ds_train = tf.data.Dataset.zip((ds_train_1, ds_train_2))
        return ds_train

同一のデータをもとに2つデータセットを作り、shuffleでseedを同じにして順番一致させ、それぞれランダムに変更を加えてから、zipを使って最後にペアにしている。ラベルは使わない。

#実験

論文にResNet18をCIFAR10で訓練した結果が掲載されているので、同様な実験を行い再現できるか確認する。

環境は、GoogleColabでTensorFlow2.5.0。

ステム部分をCIFAR用に変更したResNet18をBackboneとして、SimSiamで800エポック学習させる。
OptimizerはDecupled Weight DecayタイプのSGD(momentum=0.9)を自作して使用している。tf.kerasのl2_regularizerでは明示的に指定した層にしかWeight Decayがかからないが、この実装では全ての層にかかるようにしてある。
論文ではResNetは古いタイプのものを使っているようだが、ここではPreActタイプのV2で実装した。

10エポックごとに近傍法でValidationDataをクラス分けしてAccuracyとして出すように実装してある。(関連記事)
データ拡張の部分は論文とほぼ同じ処理に加えて、Cutoutも加えている。

800エポックとエポック数が大きいが、ResNet18ならGoogleColabのTPU使用すれば一エポック10秒程度なのでそれほど時間はかからない。

以下、実際の学習曲線。

SimSaim.png

Lossが序盤に一気に低下してから一旦戻って、そこからゆっくりと再度低下していく。この序盤の挙動はPreActタイプではないResNetでは発生しないようだ。Backboneのモデルの構成によって、細かい学習のチューニングが必要になる可能性がある。

Accuracyの最高値は89.5%で、Backboneからの特徴量だけで9割分類できたことになる。
BackboneにClassifierとして全結合層をつけ、Backboneをフリーズした状態で学習をさせると、91.1%の正解率になったが、これは論文の数値(91.8%)とほぼ同等の結果と言っていいだろう。ちなみに同じモデルを普通に学習させると95%程度は可能なので、別に高性能なわけではない。

BackboneにValidationデータを入力した場合の出力をtSNEで可視化したものがこちら。

resnet18v2_800.png

DogとCatの分類に苦労しているようだが、ラベルなしでかなりクラスター化できているように見える。

実験用Google Colabノート

#まとめ

SimSiamをtf.kerasで実装し、CIFAR10の学習で論文とほぼ同じ結果が得られた。
ラベルなし学習が必要あるいは有効と思われるケースでは、使用を検討してみても良いのではなかろうか。

##関連記事

本記事執筆後に、Qiita上により詳細な記事が作成されたので、より詳しく知りたい方はそちらも参照すると良いだろう。

以下は、"Barlow Twins"というさらに新しい自己教師あり学習手法の解説記事。

20
15
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
20
15

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?