2
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

PytorchLightningでViTを動かす

Posted at

PytorchLightningが便利かもと思い、調度触りたかったViTと共に実装してみた。
(ViTは書籍コードを移植)

所感

PytorchLightning

まだ実装して数回学習回しただけだが、結構良さそう。

  • メリット
    • train, valid, testの各ステップ毎に処理が書けるので楽
    • 評価指標も、使いたいもの羅列してオプション指定するだけなので楽
    • 複数人でやる場合にも読みやすい(チームの場合はこれ結構重要かも)
  • デメリット
    • PytorchLightningのお作法を覚えないといけない
    • 他の人のコードを使いたいときも、PytorchLightningでやるには時間かかるときがありそう

Weights and Biases

これも便利。
まだ大してカスタマイズしてないけど、グラフにどう表示するかとか、結構好きに出来そうなので使っていきたい。

環境

  • Google colab

概要

  • モデル:VisionTransformer
  • データセット:CIFAR100
  • 可視化:t-SNE
  • その他:Weights & Biases

実装

Installing module

# ライブラリのインストール
!pip install -q pytorch_lightning
!pip install -q torchmetrics
!pip install --upgrade -q wandb
import os
import matplotlib.pyplot as plt

# Pytorch
import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data.dataset import Subset
import torchvision
import torchvision.transforms as T

# PytorchLightning
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassF1Score, MulticlassAccuracy, MulticlassPrecision, MulticlassRecall
from pytorch_lightning.callbacks import Callback

# Weights and Biases
from pytorch_lightning.loggers import WandbLogger

# File
import sys
sys.path.append(base_dir)

# ViT用
import util
import eval # 未使用?

# Weights and Biasesへのログイン
import wandb
wandb.login()

Config

'''
ハイパーパラメータとオプションの設定
'''
# wandb(Weights & Biases)用に辞書型で定義
CONFIG = dict(
    save_dir = base_dir + "checkpoints",  # モデル保存先
    tsne_dir = base_dir + "t-sne_plots/", # t-sne画像の保存先
    num_classes = 10,                     # データセットのクラス数
    val_ratio = 0.2,                      # 検証に使う学習セット内のデータの割合
    num_workers = 2,                      # データローダに使うCPUプロセスの数
    num_epochs = 30,                       # 学習エポック数
    patience = 10,                        # early_stoppingのpatience
    num_samples = 200,                    # t-SNEでプロットするサンプル数
    batch_size = 32,                      # バッチサイズ
    lr_drop = 25,                         # 学習率を減衰させるエポック
    lr = 1e-2,                            # 学習率
    img_size = 32,                        # 入力画像の大きさ
    patch_size = 4,                       # パッチサイズ
    dim_hidden = 512,                     # 隠れ層の次元
    num_heads = 8,                        # マルチヘッドアッテンションのヘッド数
    dim_feedforward = 512,                # Transformerエンコーダ層内のFNNにおける隠れ層の特徴量次元
    num_layers = 6,                       # Transformerエンコーダの層数
)

'''
Weights and Biasesの設定
'''
WANDB = dict(
    project = "test-lightning",   # プロジェクトの名前
    group = "ViT",                # グループの名前
    name = "test01"               # 学習時の名前
)

Data module

'''
dataset: 平均と標準偏差を計算する対象のPyTorchのデータセット
'''
# 入力データ正規化のために学習セットのデータを使って、各チャネルの平均と標準偏差を計算
dataset = torchvision.datasets.CIFAR10(
    root='data', train=True, download=True,
    transform=T.ToTensor())
channel_mean, channel_std = util.get_dataset_statistics(dataset)
# 画像の整形を行うクラスのインスタンスを用意
data_transforms = {
    "train": T.Compose(
        [
            T.RandomResizedCrop(32, scale=(0.8, 1.0)), # 画像の一部を切り抜いてリサイズ
            T.RandomHorizontalFlip(), # 画像を水平反転(デフォで50%の確率)
            T.ToTensor(),
            T.Normalize(mean=channel_mean, std=channel_std),
        ]
    ),
    "val": T.Compose(
        [
            T.ToTensor(),
            T.Normalize(mean=channel_mean, std=channel_std),
        ]
    ),
    "test": T.Compose(
        [
            T.ToTensor(),
            T.Normalize(mean=channel_mean, std=channel_std),
        ]
    )
}
# 学習、評価セットの用意
train_dataset = torchvision.datasets.CIFAR10(
    root='data', train=True, download=True,
    transform=data_transforms["train"])

val_dataset = torchvision.datasets.CIFAR10(
    root='data', train=True, download=True,
    transform=data_transforms["val"])

test_dataset = torchvision.datasets.CIFAR10(
    root='data', train=False, download=True,
    transform=data_transforms["test"])

# 学習・検証セットへ分割するためのインデックス集合の生成
## trainとvalでデータ整形の方法が異なり(valはデータ拡張なし)、両者のデータに重複がないようにするために生成する
val_indices, train_indices = util.generate_subset(
    train_dataset, CONFIG["val_ratio"])

train_dataset = Subset(train_dataset, train_indices)
val_dataset   = Subset(val_dataset, val_indices)

image_datasets = {"train": train_dataset,
            "val": val_dataset,
            "test": test_dataset
            }

print(f'学習セットのサンプル数 : {len(image_datasets["train"])}')
print(f'検証セットのサンプル数 : {len(image_datasets["val"])}')
print(f'テストセットのサンプル数: {len(image_datasets["test"])}')
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
学習セットのサンプル数 : 40000
検証セットのサンプル数 : 10000
テストセットのサンプル数: 10000

DataLoaderのところだけPytorchLightningのModule使う
(この辺大して理解してない、たぶん構成は人によりけり)

class PlDataModule(pl.LightningDataModule):
    def __init__(self, image_datasets: dict, batch_size: int):
        super().__init__()
        self.image_datasets = image_datasets
        self.batch_size = batch_size

    def train_dataloader(self):
        return DataLoader(
            self.image_datasets["train"], batch_size=self.batch_size, shuffle=True
        )

    def val_dataloader(self):
        return DataLoader(self.image_datasets["val"], batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.image_datasets["test"], batch_size=self.batch_size)

Metrics

MetricCollectionを使うと、複数の評価指標を簡単に出してくれる

"""
metrics for the training
"""
def get_full_metrics(
    threshold=0.5,
    average_method="macro",
    num_classes=None,
    prefix=None,
    ignore_index=None,
):
    return MetricCollection(
        [
            MulticlassAccuracy(
                threshold=threshold,
                num_classes=num_classes,
                ignore_index=ignore_index,
            ),
            MulticlassPrecision(
                threshold=threshold,
                average=average_method,
                num_classes=num_classes,
                ignore_index=ignore_index,
            ),
            MulticlassRecall(
                threshold=threshold,
                average=average_method,
                num_classes=num_classes,
                ignore_index=ignore_index,
            ),
            MulticlassF1Score(
                threshold=threshold,
                average=average_method,
                num_classes=num_classes,
                ignore_index=ignore_index,
            ),
        ],
        prefix= prefix
    )

Model module

以下書籍のコードをPytorchLightningに移植した

モデル構造は超簡単に書くと以下

  • 入力画像をパッチに分解
  • パッチ平坦化
  • 全結合層
  • Transformerエンコーダ
    • レイヤー正規化→マルチヘッドアテンション→レイヤー正規化→FNN)
  • レイヤー正規化
  • 全結合層
  • 出力

Multi Head Attention

マルチヘッドアテンション
Transformerエンコーダ層で使う

class SelfAttention(nn.Module):
    '''
    自己アテンション
    dim_hidden: 入力特徴量の次元
    num_heads : マルチヘッドアテンションのヘッド数
    qkv_bias  : クエリなどを生成する全結合層のバイアスの有無
    '''
    def __init__(self, dim_hidden: int, num_heads: int,
                 qkv_bias: bool=False):
        super().__init__()

        # 特徴量を各ヘッドのために分割するので、
        # 特徴量次元をヘッド数で割り切れるか検証
        assert dim_hidden % num_heads  == 0

        self.num_heads = num_heads

        # ヘッド毎の特徴量次元
        dim_head = dim_hidden // num_heads

        # ソフトマックスのスケール値
        self.scale = dim_head ** -0.5 # (2乗根の逆数、2乗根で実装されている書籍もあった)

        # ヘッド毎にクエリ、キーおよびバリューを生成するための全結合層
        self.proj_in = nn.Linear(
            dim_hidden, dim_hidden * 3, bias=qkv_bias) # 3はq, k, vの3つ分にするということ?

        # 各ヘッドから得られた特徴量を一つにまとめる全結合層
        self.proj_out = nn.Linear(dim_hidden, dim_hidden)

    '''
    順伝播関数
    x: 入力特徴量, [バッチサイズ, 特徴量数, 特徴量次元]
    '''
    def forward(self, x: torch.Tensor):
        bs, ns = x.shape[:2]

        qkv = self.proj_in(x) # 線形変換 + 特徴量を数を3倍?

        # view関数により
        # [バッチサイズ, 特徴量数, QKV, ヘッド数, ヘッドの特徴量次元]
        # permute関数により
        # [QKV, バッチサイズ, ヘッド数, 特徴量数, ヘッドの特徴量次元]
        qkv = qkv.view(
            bs, ns, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)

        # クエリ、キーおよびバリューに分解
        q, k, v = qkv.unbind(0)

        # クエリとキーの行列積とアテンションの計算(今回マスクは不使用)
        # attnは[バッチサイズ, ヘッド数, 特徴量数, 特徴量数]
        attn = q.matmul(k.transpose(-2, -1))
        attn = (attn * self.scale).softmax(dim=-1)

        # アテンションとバリューの行列積によりバリューを収集
        # xは[バッチサイズ, ヘッド数, 特徴量数, ヘッドの特徴量次元]
        x = attn.matmul(v)

        # permute関数により
        # [バッチサイズ, 特徴量数, ヘッド数, ヘッドの特徴量次元]
        # flatten関数により全てのヘッドから得られる特徴量を連結して、
        # [バッチサイズ, 特徴量数, ヘッド数 * ヘッドの特徴量次元]
        x = x.permute(0, 2, 1, 3).flatten(2)
        x = self.proj_out(x) # 線形変換

        return x

FNN(Transformerエンコーダ内)

Transformerエンコード層の中で使う順伝播型ニューラルネットワーク(feedforward neural networks)の定義

class FNN(nn.Module):
    '''
    Transformerエンコーダ内の順伝播型ニューラルネットワーク
    dim_hidden     : 入力特徴量の次元
    dim_feedforward: 中間特徴量の次元
    '''
    def __init__(self, dim_hidden: int, dim_feedforward: int):
        super().__init__()

        self.linear1 = nn.Linear(dim_hidden, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, dim_hidden)
        self.activation = nn.GELU()

    '''
    順伝播関数
    x: 入力特徴量, [バッチサイズ, 特徴量数, 特徴量次元]
    '''
    def forward(self, x: torch.Tensor):
        x = self.linear1(x)
        x = self.activation(x)
        x = self.linear2(x)

        return x

Transformerエンコーダ層

Transformerエンコーダ層の定義(上のFNNクラスとSelfAttentionクラスを使用)

class TransformerEncoderLayer(nn.Module):
    '''
    Transformerエンコーダ層
    dim_hidden     : 入力特徴量の次元
    num_heads      : ヘッド数
    dim_feedforward: 中間特徴量の次元
    '''
    def __init__(self, dim_hidden: int, num_heads: int,
                 dim_feedforward: int):
        super().__init__()

        self.attention = SelfAttention(dim_hidden, num_heads)
        self.fnn = FNN(dim_hidden, dim_feedforward)

        self.norm1 = nn.LayerNorm(dim_hidden)
        self.norm2 = nn.LayerNorm(dim_hidden)

    '''
    順伝播関数
    x: 入力特徴量, [バッチサイズ, 特徴量数, 特徴量次元]
    '''
    def forward(self, x: torch.Tensor):
        x = self.norm1(x)
        x = self.attention(x) + x
        x = self.norm2(x)
        x = self.fnn(x) + x

        return x

Vision Transformerの実装

class VisionTransformer(nn.Module):
    '''
    Vision Transformer
    num_classes    : 分類対象の物体クラス数
    img_size       : 入力画像の大きさ(幅と高さ等しいことを想定)
    patch_size     : パッチの大きさ(幅と高さ等しいことを想定)
    dim_hidden     : 入力特徴量の次元
    num_heads      : マルチヘッドアテンションのヘッド数
    dim_feedforward: FNNにおける中間特徴量の次元
    num_layers     : Transformerエンコーダの層数
    '''
    def __init__(self, num_classes: int, img_size: int,
                 patch_size: int, dim_hidden: int, num_heads: int,
                 dim_feedforward: int, num_layers: int):
        super().__init__()

        # 画像をパッチに分解するために、画像の大きさがパッチの大きさで割り切れるか確認
        assert img_size % patch_size == 0

        self.img_size = img_size
        self.patch_size = patch_size

        # パッチの行数と列数はともにimg_size // patch_sizeであり、
        # パッチ数はその2乗になる
        num_patches = (img_size // patch_size) ** 2

        # パッチ特徴量はパッチを平坦化することにより生成されるため
        # その次元はpatch_size * patch_size * 3(RGBチャネル)
        dim_patch = 3 * patch_size ** 2

        # パッチ特徴量をTransformerエンコーダーに入力する前に
        # パッチ特徴量の次元を変換する全結合層
        self.patch_embed = nn.Linear(dim_patch, dim_hidden)

        # 位置埋め込み(パッチ数 + クラス埋め込みの分を用意)
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, dim_hidden))

        # クラス埋め込み
        self.class_token = nn.Parameter(
            torch.zeros((1, 1, dim_hidden))
        )

        # Transformerエンコーダ層
        self.layers = nn.ModuleList([TransformerEncoderLayer(
            dim_hidden, num_heads, dim_feedforward
        ) for _ in range(num_layers)])

        # ロジット(ニューロンの出力値)を生成する前のレイヤー正規化と全結合
        self.norm = nn.LayerNorm(dim_hidden)
        self.linear = nn.Linear(dim_hidden, num_classes)

    '''
    順伝播関数
    x           : 入力, [バッチサイズ, 入力チャネル数, 高さ, 幅]
    return_embed: 特徴量を返すかロジットを返すかを選択する真偽値
    '''
    def forward(self, x: torch.Tensor, return_embed: bool=False):
        bs, c, h, w = x.shape

        # 入力画像の大きさがクラス生成時に指定したimg_sizeと
        # 合致しているか確認
        assert h == self.img_size and w == self.img_size

        # 高さ軸と幅軸をそれぞれパッチ数 * パッチの大きさに分解し、
        # [バッチサイズ, チャネル数, パッチの行数, パッチの大きさ,
        # パッチの列数, パッチの大きさ] の形にする
        x = x.view(bs, c, h // self.patch_size, self.patch_size,
                w // self.patch_size, self.patch_size)

        # permute関数により、
        # [バッチサイズ, パッチ行数, パッチ列数, チャネル,
        #                   パッチの大きさ, パッチの大きさ] の形にする
        x = x.permute(0, 2, 4, 1, 3, 5)

        # パッチを平坦化
        # permute関数適用後にはメモリ上のデータ配置の整合性の関係で
        # view関数を使えないのでreshape関数を使用
        x = x.reshape(
            bs, (h // self.patch_size) * (w // self.patch_size), -1)

        x = self.patch_embed(x)

        # クラス埋め込みをバッチサイズ分用意
        class_token = self.class_token.expand(bs, -1, -1)

        x = torch.cat((class_token, x), dim=1)

        x += self.pos_embed

        # Transformerエンコーダ層を適用
        for layer in self.layers:
            x = layer(x)

        # クラス埋め込みをベースとした特徴量を抽出
        x = x[:, 0]

        x = self.norm(x)

        if return_embed: # Trueなら特徴量、Falseでロジットを返す?
            return x

        x = self.linear(x)

        return x

LightningModule

VisionTransformerクラスを実装

class PlViTModule(pl.LightningModule):
    # def __init__(self, lr: float, num_classes: int):
    def __init__(self):
        super(PlViTModule, self).__init__()

        # 各パラメータ等定義
        self.lr = CONFIG['lr']
        self.num_classes = CONFIG['num_classes']
        self.img_size = CONFIG['img_size']
        self.patch_size = CONFIG['patch_size']
        self.dim_hidden = CONFIG['dim_hidden']
        self.num_heads = CONFIG['num_heads']
        self.dim_feedforward = CONFIG['dim_feedforward']
        self.num_layers = CONFIG['num_layers']

        self.loss_fn = nn.CrossEntropyLoss()

        # モデル定義
        self.models = VisionTransformer(num_classes = self.num_classes,
                                        img_size = self.img_size,
                                        patch_size = self.patch_size,
                                        dim_hidden = self.dim_hidden,
                                        num_heads = self.num_heads,
                                        dim_feedforward = self.dim_feedforward,
                                        num_layers = self.num_layers)

        # 評価指標定義
        ## 学習、評価、テストそれぞれに対して指定できる
        self.train_metrics = get_full_metrics(
            num_classes = self.num_classes,
            prefix = "train_",
        )
        self.valid_metrics = get_full_metrics(
            num_classes = self.num_classes,
            prefix = "valid_",
        )
        self.test_metrics = get_full_metrics(
            num_classes = self.num_classes,
            prefix = "test_",
        )

        # ハイパーパラメーターをself.hparamsに保存する (wandbによって自動ロギングされる)
        self.save_hyperparameters()

    def forward(self, x: torch.Tensor, return_embed: bool = False): # ここでもreturn_embed必要
        output = self.models(x)
        return output

    # 学習時のステップ(各バッチ毎に実行する処理)
    def training_step(self, batch, batch_idx):
        images, target = batch
        preds = self.forward(images)
        # クロスエントロピー、モデルの学習はこの指標でやる
        loss = self.loss_fn(preds, target)
        # print(preds, target, loss)

        # こっちは考察用のメトリクス
        preds_for_metrics = F.softmax(preds, dim=1)
        self.train_metrics(preds_for_metrics, target)

        # 単一のログを記録
        self.log(
            "train_loss",
            loss,
            prog_bar=True,
            logger=True,
            on_epoch=True,
            on_step=True,
        )
        # 複数のログを記録
        self.log_dict(
            self.train_metrics,
            prog_bar=False,
            logger=True,
            on_epoch=True,
            on_step=True,
        )
        return {"loss": loss}  # このlossを基にモデルが更新される

    # 評価時のステップ
    def validation_step(self, batch, batch_idx):
        images, target = batch
        preds = self.forward(images)
        loss = self.loss_fn(preds, target)

        preds_for_metrics = F.softmax(preds, dim=1)

        self.valid_metrics(preds_for_metrics, target)

        self.log(
            "valid_loss",
            loss,
            prog_bar=True,
            logger=True,
            on_epoch=True,
            on_step=True,
        )
        self.log_dict(
            self.valid_metrics,
            prog_bar=True,
            logger=True,
            on_epoch=True,
            on_step=True,
        )
        return {"valid_loss": loss, "preds": preds_for_metrics} # validation_endメソッドのoutputsにはこのvalid_lossとpredsが入る

    # 検証終了時のステップ
    def validation_end(self, outputs):
        avg_loss = torch.stack([x["valid_loss"] for x in outputs]).mean()
        return {"avg_val_loss": avg_loss}

    # テスト時のステップ
    def test_step(self, batch, batch_idx):
        images, target = batch
        preds = self.forward(images)
        loss = self.loss_fn(preds, target)

        preds_for_metrics = F.softmax(preds, dim=1)
        self.test_metrics(preds_for_metrics, target)

        self.log(
            "test_loss",
            loss,
            prog_bar=True,
            logger=True,
            on_epoch=True,
            on_step=True,
        )
        self.log_dict(
            self.test_metrics,
            prog_bar=True,
            logger=True,
            on_epoch=True,
            on_step=True,
        )

        return {"test_loss": loss}

    # オプティマイザーの設定
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.models.parameters(), lr=self.lr)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)
        return [optimizer], [scheduler]

Logger and callback

model_checkpointとearly_stoppingの設定。
model_checkpointはモデルの保存や監視指標等を設定できる。

# callbacksの定義
model_checkpoint = ModelCheckpoint(
    filename = WANDB['group'] + "_" + WANDB['name'] + "_{epoch}_{valid_loss:.4f}",  # モデルファイルの名前
    dirpath = CONFIG["save_dir"],           # モデルを保存するディレクトリのパス
    monitor = "valid_loss",                 # 監視する指標
    mode = "min",                           # 最小化を目指す
    save_top_k = 1,                         # 最良のK個のモデルを保存
    save_last = False,                      # 最後のエポックのモデルを保存
)

early_stopping = EarlyStopping(
    monitor="valid_loss",
    mode="min",
    patience=CONFIG["patience"],
)

LogPredictionSamplesCallbackではwandbへの保存設定をしている
(検証バッチ毎にサンプル画像と予測結果を保存)

class LogPredictionSamplesCallback(Callback):

    # 検証バッチが終了したときに呼び出される
    def on_validation_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx):

        # `outputs`は`LightningModule.validation_step`からくる
        preds = torch.argmax(outputs["preds"], -1)

        # 1つ目のバッチから20個のサンプル画像の予測をログに記録
        if batch_idx == 0:
            n = 20
            x, y = batch
            images = [img for img in x[:n]]
            captions = [f'正解: {y_i} - 予測: {y_pred}' for y_i, y_pred in zip(y[:n], preds[:n])]

            # オプション1:`WandbLogger.log_image`を使って画像をログに記録する
            wandb_logger.log_image(
                key='sample_images',
                images=images,
                caption=captions)

            # オプション2:画像と予測をW&Bテーブルとしてログに記録する
            columns = ['image', 'ground truth', 'prediction']
            data = [[wandb.Image(x_i), y_i, y_pred] for x_i, y_i, y_pred in list(zip(x[:n], y[:n], outputs["preds"][:n]))]
            wandb_logger.log_table(
                key='sample_table',
                columns=columns,
                data=data)

Trainer

学習実行時の設定

# wandb loggerの設定
wandb_logger = WandbLogger(project = WANDB["project"],
                           config = CONFIG, 
                           group = WANDB["group"],
                           name = WANDB["name"],
                           job_type='train')

# Trainerの設定
trainer = pl.Trainer(
    max_epochs=CONFIG['num_epochs'],
    accelerator="auto", # "gpu"
    callbacks=[model_checkpoint, early_stopping, LogPredictionSamplesCallback()],
    logger=[wandb_logger],
)

Run

学習実行

data_module = PlDataModule(image_datasets, CONFIG["batch_size"])
model = PlViTModule()

# 学習実行
trainer.fit(
    model,
    datamodule=data_module,
)

# テスト実行
## 自動的に学習結果からbest weightsをロードするらしい
trainer.test(dataloaders=data_module.test_dataloader())
INFO:pytorch_lightning.callbacks.model_summary:
  | Name          | Type              | Params
----------------------------------------------------
0 | loss_fn       | CrossEntropyLoss  | 0     
1 | models        | VisionTransformer | 9.5 M 
2 | train_metrics | MetricCollection  | 0     
3 | valid_metrics | MetricCollection  | 0     
4 | test_metrics  | MetricCollection  | 0     
----------------------------------------------------
9.5 M     Trainable params
0         Non-trainable params
9.5 M     Total params
38.095    Total estimated model params size (MB)
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃          Test metric           ┃          DataLoader 0          ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ test_MulticlassAccuracy_epoch  │      0.10000000149011612       │
│  test_MulticlassF1Score_epoch  │      0.01818181946873665       │
│ test_MulticlassPrecision_epoch │      0.009999999776482582      │
│  test_MulticlassRecall_epoch   │      0.10000000149011612       │
│        test_loss_epoch         │       2.3063504695892334       │
└────────────────────────────────┴────────────────────────────────┘
[{'test_loss_epoch': 2.3063504695892334,
  'test_MulticlassAccuracy_epoch': 0.10000000149011612,
  'test_MulticlassPrecision_epoch': 0.009999999776482582,
  'test_MulticlassRecall_epoch': 0.10000000149011612,
  'test_MulticlassF1Score_epoch': 0.01818181946873665}]

こんな感じで評価指標まで自動で算出してくれる
上記はepoch2の結果だが、30くらいやると6-7割の正答率にはなる

t-SNE

書籍でやっていたので、ついでにこれも移植
(画像はepoch2の結果なので、全然分類できてない)

# ベストモデルのパスからファイル名を取得
name, _ = os.path.splitext(best_model_path.split("/")[-1])

# 保存パス
tsne_path = CONFIG["tsne_dir"] + WANDB['group'] + "_" + name + ".png"

# t-SNEを使って特徴量の分布をプロット
util.plot_t_sne(data_module.test_dataloader(), model, CONFIG["num_samples"], model.device, tsne_path)

# プロット画像をwandbに保存
wandb_logger.log_image(
    key = 't-SNE Plot',
    images = [plt.imread(tsne_path)],
)

# Close wandb
wandb.finish()

output_38_0.png

Weights & Biases

結果を溜めて可視化できる。
ViTとResNetでグループを分けたり、比較したい結果を選んで可視化することもできそう。
何回も学習をすると、過去試した内容や比較がこんがらがってくるので、便利。
Slack等と連携して、学習終了時に通知出すとかもできる。

image.png
image.png
image.png
image.png

参照

非常に参考になりました。ありがとうございます。

以上

2
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?