0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

WandBのSweepsやってみた(ViT + PyTorchLightning + Transformer-Explainability)※精度は上がってないよ

Posted at

この記事を書いたのは2023年の為、現在とはWandB(Weights & Biases)の仕様が変わっている箇所もあると思いますのでご容赦ください。

STL10データセットの画像分類モデルをWandBのSweepsを使ってパラメータチューニングしてみた結果をまとめました。

概要

実施内容

  • STL10データセットによる画像分類をVision Transformer(Pytorch Lightning使用)で実装しました
  • コードはこちら(github)
  • hila-cheferさんのTransformer-Explainabilityも実装してみました
  • STL10データセットは犬、猫、飛行機、車のように10クラスからなる、教師データ5000枚、テストデータ8000枚で構成されています。テストデータの方が多く、モデルの汎化性能が問われます。

image.png

実行環境

  • Google Colab pro(GPU:T4(約15GB)、メモリ:約50GB(ハイメモリ))
  • テストデータでの検証のみ Colab pro+ 使用(GPU:A100(約40GB)、メモリ:約80GB)

参考(WandBコース)

本レポートはWandBのコースであるEffective MLOps: Model Developmentのlesson2(beyond baseline model)の課題提出した際に作ったレポートを元にしています。

WandBの使い方が動画で分かり易く学べますし(日本語字幕あり)、fastaiを使用して、非常にシンプルにHugging Face等の複数の学習済みモデルを検証する方法も学べてとても有用でした。さらにlesson3からはもっと面白そうでした・・・。海外のとても優秀なMLエンジニアの話が聞けるので非常にお勧めです。

lesson1での提出レポート:STL10-EDAVisionTransformer-STL10-Baseline

結果

  • Sweepsではvalidation accuracyがbaselineより最大で15%程上昇
  • テストデータでは特に改善せず

所感

  • 自分ではすぐには思いつけないパラーメータの組み合わせで、精度が上がっている様子には驚きました
  • テストデータでの精度は上がりませんでしたが、それは学習データの量や学習のさせ方に問題があると思われます(そもそもViTで学習済みモデル使わないとか、実践ではあまりやらないですかね)

Sweeps実行

条件

以下のように簡単に条件が指定できます。
今回メソッドはベイズ、最大試行回数は50回でやりました(実際は朝起きて止めたので、47回で終了)。

sweep_config = {
    'method': 'bayes',
    'metric': {
        'name': '01_loss/valid',  
        'goal': 'minimize' 
    },
    'parameters': {
        'lr': {
            'min': 1e-6,
            'max': 5e-4,
            'distribution': 'uniform'
        },
        'batch_size': {
            'values': [32, 64, 128]
        },
        'dim_hidden': {
            'values': [256, 512, 768, 1024] 
        },
        'num_layers': {
            'values': [3, 6, 9, 12] 
        },
        'patch_size': {
            'values': [4, 8, 16, 32] 
        },
        'num_heads': {
            'values': [4, 8, 16, 32] 
        }
    },
    'early_terminate': {
        'type': 'hyperband', 
        'min_iter': 10,       
        'eta': 3            
    },
    'run_cap': 50 
}

以下コードでColabでも問題なく実行できました。

レッスン動画ではモデル等のコードは全部pyファイルにして、Sweepsの実行のみColabでやってましたが、自分はpyファイル化は一部のみで、メインとなるモデル部分等はnotebookのままやりました。ただ、Colabだとメモリ消去が上手くできない気がするので、本来は学習コードはtrain.py等にするのが良いのではと思います。

皆様既にご存じと思いますが、こういったコードはだいたいchat-GPT先生が教えてくれます。WandBのコミュニティチャンネルでもGPTが使えるようになっていて、そこで聞くこともできます(さらにColab内でも生成AI機能がつきましたが、そこは私はまだあまり使ってません)。

def train_wrapper():
    with wandb.init() as run:
        wandb_config = run.config


        # 既存のtrain_configのコピーを作成
        train_config_updated = train_config.copy()
        # Sweepで変更されるすべてのパラメータを更新
        train_config_updated['lr'] = wandb_config.lr
        train_config_updated['batch_size'] = wandb_config.batch_size
        train_config_updated['dim_hidden'] = wandb_config.dim_hidden
        train_config_updated['num_layers'] = wandb_config.num_layers
        train_config_updated['patch_size'] = wandb_config.patch_size
        train_config_updated['num_heads'] = wandb_config.num_heads


        # 更新された設定でトレーニング関数を実行
        train(train_config_updated)
# Sweepの作成
sweep_id = wandb.sweep(sweep_config, project=WANDB_PROJECT)
# Sweepエージェントの実行
wandb.agent(sweep_id, train_wrapper)

結果

lossとaccuracyはPyTorch Lightninのメソッドであるself.logを使えば簡単にWandBに記録できます(これも勿論chat-GPT)。

今回PyTorch Lightningも初めて使いましたが、各ステップやエポック終了時の処理を分かりやすく記載でき非常に便利だと思いました。WandB公式にも、PyTorch Lightning + WandBのチュートリアルがありとても参考になりました。特にnotebookのサンプルがあるのが助かりました。

公式チュートリアル:PyTorch LightningとWeights & Biasesを使った画像分類

def training_step(self, batch, batch_idx):
    images, target = batch
    preds = self.forward(images)
    loss = self.loss_fn(preds, target)
    self.accuracy(preds, target)
    current_accuracy = self.accuracy.compute()
    self.log("01_loss/train", loss, prog_bar=True, logger=True, on_epoch=True, on_step=False)
    self.log("02_metrics/accuracy_train", current_accuracy, prog_bar=False, logger=True, on_epoch=True, on_step=False)
    return {"loss": loss}


def validation_step(self, batch, batch_idx):
    images, target = batch
    preds = self.forward(images)
    loss = self.loss_fn(preds, target)
    self.accuracy(preds, target)
    self.log("01_loss/valid", loss, prog_bar=True, logger=True, on_epoch=True, on_step=False)
    self.log("02_metrics/accuracy_valid", self.accuracy, prog_bar=False, logger=True, on_epoch=True, on_step=False)
  return {"valid_loss": loss}

以下はSweepsの可視化結果です。

image.png

下のウネウネしている図がとても良く、どういった組み合わせでどういう結果がでるのか傾向等読み取れます。今回だとdim_hidden(マルチヘッドアテンションの隠れ層の次元)が大きく、patch_sizeが小さいと精度が高くなりやすそうです

評価指標に対する変数の重要度と相関も可視化されており、ウネウネで見て取れるように隠れ層とパッチサイズについて、それぞれ正の相関と負の相関がありそうです

こういう可視化を自分でやると実装自体が手間ですし、モデルの条件毎の比較やチームでの共有も手間がかかります。初めてWandBを見たとき、「まさにこういうのが欲しかった!」と普通に感動しました。

image.png

今回サーチ手法をベイズで実施しましたが、ベイズだと最初の数回で精度が良いパラメーターを中心に検証するようで、検証パラメーターに偏りがあるのかもしれないと思いました(検証回数が多ければ良いのでしょうが、GPUリソースや時間がその分かかります)。検証したいパラメーターや範囲が多いときは、最初はランダムでやって、その結果を受けてある程度範囲を絞ってから、ベイズで検証するのでも良いのかと思いました。

このウネウネは見てるだけで面白いですが、 パラメータチューニングは基本的に微調整フェーズだと思うので、EDAやゴール設計等をしっかりやっていない段階でやるのは注意が必要だとも思います。

Google ColabのT4で実行し、47回完了するのに約10時間、約20リソースくらい消費しました(T4だと1時間に2くらい減ります)。

テストデータによる検証

条件

検証時のモデルのパラメーターは下の対象モデルをクリックすると情報が確認できます(こういう機能も素晴らしい)。

テストでの検証はTransformer-Explainability(勾配情報の可視化)も使いたかったので、Colabでもメモリ不足にならないパラーメータの組み合わせを選びました。(それでもColab pro+でA100を使わなければできませんでした。メモリ効率は全然考慮してないので、効率悪いと思いますが、計算量多い処理を何も考えずに組み込むと後で苦しむことを学びました・・・)

結果

1. Accuracy

以下棒グラフは、最上段のOverallが全クラスでのAccuracy、他はクラス毎のAccuracyです。

validationでは大きく精度が上がりましたが、テストデータでは特に改善は見られませんでした。STL10データセットは教師データ5000枚、テストデータ8000枚とテストデータの方が多いです。モデルが汎用的に学習出来ていないと、学習時は精度が上がっているように見えても、テストデータでの精度が上がらないように思います。

image.png

棒グラフは以下コードで作成しています。

test_stepでstep(バッチサイズ)毎に予測結果を計算し、on_test_epoch_endでテスト終了時に全体のAccuracyとクラス毎のAccuracyをwandb.logで棒グラフとして保存します。

def test_step(self, batch, batch_idx):
    images, targets = batch
    preds = self.forward(images)
    loss = self.loss_fn(preds, targets)
    self.accuracy(preds, targets)
    self.log("01_loss/test", loss, prog_bar=True, logger=True, on_epoch=True, on_step=False)
    self.log("02_metrics/accuracy_test", self.accuracy, prog_bar=False, logger=True, on_epoch=True, on_step=False)


    pred_labels = preds.argmax(dim=1)
    class_probabilities = preds.softmax(dim=1)


    for i in range(images.size(0)):
        self.test_predictions.append({
            "images": images[i],
            "pred_labels": pred_labels[i],
            "true_labels": targets[i],
            "class_probabilities": class_probabilities[i],
        })


    self.accuracy(preds, targets)


    for i in range(self.num_classes):
        class_mask = targets == i
        class_preds = preds[class_mask]
        class_targets = targets[class_mask]
        if class_targets.nelement() != 0:
            self.class_accuracy[i](class_preds, class_targets)


    return {"test_loss": loss}

def on_test_epoch_end(self):
    overall_accuracy = self.accuracy.compute()
    class_accuracies = [self.class_accuracy[i].compute() for i in range(self.num_classes)]
    data = [["Overall", overall_accuracy]] + [[self.class_names[i], class_accuracies[i]] for i in range(len(self.class_names))]
    accuracy_table = wandb.Table(data=data, columns=["Class", "Accuracy"])
    wandb.log({"Class-wise Test Accuracies": wandb.plot.bar(accuracy_table, "Class", "Accuracy", title="Class-wise Test Accuracies")})

valid accとtest accで10%以上差があり、全然ロバストでないことがわかります。

image.png

2. t-SNEによるプロット

精度が高めのairplaneやship等は比較的分かれていて、精度が低めの動物クラスはあまり分けられていないのがわかります。

validation終了時に実行しています(20エポック毎)。

image.png

3. Transformer-Explainability

Transformer-Explainabilityはhila-cheferさんが公開してくれているものを使用。可視化する仕組みは元のコードそのままでだいたいいけますが、自分で使っているViTに対応できるようにする必要があります。アテンションマップや勾配を保存するコードを加える+その他ちょこちょこしたら動作しました。

以下コードの構成です(間違っていたらすみません)。

推論時は、save_attention_mapにアテンションを格納し、save_attn_gradientsにはregister_hook関数を使って、勾配情報を格納します。

image.png

可視化時には、get_attn_gradientsとget_attention_mapで保存したアテンションマップと勾配情報を呼び出し、これを元に可視化をします。

image.png

参考までにマルチヘッドアテンションのコードを添付します。

class SelfAttention(nn.Module):


    def __init__(self, dim_hidden: int, num_heads: int,
                 qkv_bias: bool=False):
        super().__init__()


        assert dim_hidden % num_heads  == 0
        self.num_heads = num_heads
        dim_head = dim_hidden // num_heads
        self.scale = dim_head ** -0.5 


        self.proj_in = nn.Linear(
            dim_hidden, dim_hidden * 3, bias=qkv_bias) 


        self.proj_out = nn.Linear(dim_hidden, dim_hidden)


        # アテンションの勾配とマップ(アテンションの値(類似度))格納用
        self.attn_gradients = None
        self.attention_map = None


    # アテンションの勾配を保存
    def save_attn_gradients(self, attn_gradients):
        self.attn_gradients = attn_gradients


    # アテンションの勾配を取得
    def get_attn_gradients(self):
        return self.attn_gradients


    # アテンションマップを保存
    def save_attention_map(self, attention_map):
        self.attention_map = attention_map


    # アテンションマップを取得
    def get_attention_map(self):
        return self.attention_map




    def forward(self, x: torch.Tensor, register_hook=False):


        bs, ns = x.shape[:2]


        qkv = self.proj_in(x) 


        qkv = qkv.view(
            bs, ns, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)


        q, k, v = qkv.unbind(0)


        attn = q.matmul(k.transpose(-2, -1))
        attn = (attn * self.scale).softmax(dim=-1)


        # アテンションマップを保存
        self.save_attention_map(attn)


        # アテンションの勾配を保存するフックを登録
        if register_hook:
            attn.register_hook(self.save_attn_gradients)


        x = attn.matmul(v)


        x = x.permute(0, 2, 1, 3).flatten(2)
        x = self.proj_out(x) 


        return x

テスト時のモデルで実行した結果が以下です。validation時に20epoch毎に32枚のサンプル画像で可視化しています。

学習が始まるとすぐに、モデルは画像のごく一部の特徴しか認識しなくなっているように見えます。そして、学習が進んでもほとんど変化がほとんどありません。valid lossが20epochくらいで停滞していることと辻褄が合うように思えます。勾配が消失している可能性があるのでしょうか。

いずれにしても、このような認識をしていては、テストデータで精度が低いのもうなずけます。(又は、私のTransformer-Explainabilityの実装に問題があるのかもしれません)

Transformer-Explainabilityについては、今回はいまいち上手くいっているのかわかりませんが、エポックが進むごとに、モデルの認識がどのように変わっているのか分かり、チューニング時の仮説立てに有効だと思いました。(ただ計算量が多いので、常にやるようなものでもないかもしれません。特に学習完了後だけでなく、エポックによる変化を見ようとして、何度もメモリ不足に陥りました・・・)

  • 初期ステップ
    image.png

  • 学習途中
    image.png

t-SNEとTransformer-ExplainabilityはPyTorchLightningのon_validation_epoch_endメソッドでvalidation終了時に実行しています。

def on_validation_epoch_end(self, trainer, pl_module):
        if trainer.current_epoch % self.epoch_interval == 0:
            model = pl_module
            dataloader = trainer.val_dataloaders


            """サンプル画像と正解ラベルを取得"""
            batch = next(iter(dataloader))
            x, y = batch


            device = model.device
            x, y = x.to(device), y.to(device)
            x.requires_grad_()  # 勾配を追跡


            # サンプル数に応じてデータを取得
            images = [img for img in x[:self.num_samples]]


            def denormalize(image, mean, std):
                mean = torch.tensor(mean).view(3, 1, 1).to(image.device)
                std = torch.tensor(std).view(3, 1, 1).to(image.device)
                return image * std + mean


            # 元画像の逆正規化
            images = [denormalize(img, self.channel_mean, self.channel_std) for img in images]


            """t-SNEとTransformer-Explainability用にモデルを保存してロード"""
            vee_model = copy.deepcopy(pl_module)
            te_images_with_captions = []


            """Transformer-Explainability"""
            # 予測値の取得
            preds = [vee_model(img.unsqueeze(0).cuda()) for img in images]


            with torch.set_grad_enabled(True):
                te = Transformer_Explainability.Transformer_Explainability(model=vee_model, cls_to_idx=STL10_CLASSES)


                for original_img, pred in zip(images, preds):
                    te_image = te.generate_visualization(original_img)
                    te_image_tensor = torch.tensor(te_image).to(original_img.device).float() / 255.0
                    te_image_tensor = te_image_tensor.permute(2, 0, 1)


                    # 元の画像とte_imageを横に並べて結合
                    combined_image = torch.cat((original_img, te_image_tensor), dim=2)  # dim=2は横方向に結合するため


                    # wandbはnumpy配列を受け取るので、テンソルをnumpy配列に変換
                    combined_image_np = combined_image.detach().cpu().numpy().transpose(1, 2, 0)  # CHW -> HWC


                    # 予測結果の文字列取得
                    top_str = te.print_top_classes(pred)
                    # wandb保存用のリスト
                    te_images_with_captions.append((combined_image_np, top_str))


            # 画像とそのキャプションを組み合わせてwandbにログ
            wandb.log({"te_images": [wandb.Image(image, caption=caption) for image, caption in te_images_with_captions]})


            """t-SNE"""
            util.plot_t_sne(self.data_module.test_dataloader(),
                            vee_model,
                            self.tsne_samples,
                            vee_model.device,
                            self.tsne_img_path)

            # t-SNEプロットをログ
            wandb.log({"t-SNE Plot (Valid End)": [wandb.Image(plt.imread(self.tsne_img_path))]})

4. 勾配情報

勾配消失の可能性がある?と思ったので、勾配情報のグラフも保存されるようにしました。

以下コードでlog="all"とするとメトリクスやパラメータだけでなく、gradientsも保存されるようです。とても簡単。全ネットワークの情報が保存されるので、グラフが大量に保存されます。

可視化しただけで、見方は良く分かっておらず、例えば以下の全結合層はずっと0近辺の値が多いようですが、それが勾配消失かまではわかっていません。

wandb_logger.watch(model, log="all", log_freq=config["log_freq"])

image.png

最後に

  • 今回、何となくViTをスクラッチで書くことから始め(本写経してるだけ)、PyTorch Lightning便利そうだ思って実装し、WandBは元々少し使ってましたが、さらに便利だと思い、また実装し、本で読んだTransformer-Explainabilityも面白そうだと思って実装し・・・と気づいたらどんどん拡散していました・・・

  • そんなに機械学習していない自分が言うのも何ですが、WandBは機械学習の実務をやっている人の気持ちをとてもわかっていると思います。OpenAIが全ての機械学習をWandBで回しているというのも頷けます。モデルの条件、メトリクスの可視化、考察等、整理してやらないと前と同じ実験してたり・・・でもこういう可視化や整理って結構手間なのでWandBは非常に有用だと思います。個人だと無料で使えるとは恐るべしです!

  • 文句を一つ言うなら、レポートの動作が少し不安定ですね。連打してたらクラッシュして、書いていた内容が吹き飛ぶということが2,3回起きました。15分に一回保存することをお勧めします(Ctrl+Z 連打するとクラッシュするかも、いやとにかく連打がダメか・・・)。

  • コードはこちら(github)

参照

追記(ミス補足)

レポート書いてから、とんでもない凡ミスに気づきました。

上で載せている結果は、以下のtrain_wrapperのbatch_size以降がコメントアウトされている状態で実行した結果でした・・・よって変わっているのはlrとbatch_sizeのみで、それ以降はWandBのログは変わってるけど、実施はデフォルトのままです。

sweep_configでパラメータの探索範囲を指定しても、train_wrapperで反映していないと実際の学習には反映されません。しかも、WandB上では反映されているように見えるから気付かなかった・・・(アテンションの隠れ層の次元:1024, 層の数:12とかメモリ不足にならないんだ~とか思ってましたが、T4では出来る訳ないのに不覚でした)

def train_wrapper():
    with wandb.init() as run:
        wandb_config = run.config


        # 既存のtrain_configのコピーを作成
        train_config_updated = train_config.copy()
        # Sweepで変更されるすべてのパラメータを更新
        train_config_updated['lr'] = wandb_config.lr
        train_config_updated['batch_size'] = wandb_config.batch_size
        train_config_updated['dim_hidden'] = wandb_config.dim_hidden
        train_config_updated['num_layers'] = wandb_config.num_layers
        train_config_updated['patch_size'] = wandb_config.patch_size
        train_config_updated['num_heads'] = wandb_config.num_heads


        # 更新された設定でトレーニング関数を実行
        train(train_config_updated)

以下はちゃんと全部反映したバージョンです。メモリ不足になりまくったので、最初よりだいぶパラメータの範囲を狭くしました。そしてA100でやりました。

結果については、validでは良くてもテストデータではあんまり変わらないのは同じだったので、全体的な傾向は同じだろうと思い、上の間違った結果もそのまま載せています。

image.png

image.png

今回は9回しかやっていません。というか途中で理由は不明ですが止まってました。手法をベイズではなくランダムでやりましたが、探索範囲が広い場合はやはりランダムの方が変に絞らず幅広く探索するというメリットもあるかなと思いました。

9回だけでも傾向の推測には役立ちそうで、やはりSweepsは便利だと思います。

image.png

以上 長文お読みいただきありがとうございました。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?