1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

【深層学習】モデルの部分ロード(層別の学習率の設定含)

Posted at

こんにちは!深層学習のモデルを部分ロードして学習率を変える方法のメモ書きです。
大きく4つの学習を方法を取り上げます。

  1. Finetune
  2. 転移学習
  3. 転移学習+最終block学習
  4. モデルの一部形状を変えて学習

準備

  • dataloader作成(今回はHAM10000データセットを時間短縮のため層別サンプリングをして半分にしてからtrainとvalidに分割)
  • モデルはtimmのpretrain済みの vit_base_patch16_224 を使用

1. Finetune

モデルのロード

import timm

def build_model(model_name, pretrained, num_classes):
    model = timm.create_model(model_name, pretrained=pretrained, num_classes=num_classes)
    return model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_classes = 7
model = build_model(model_name="vit_base_patch16_224", pretrained=True, num_classes=num_classes).to(device)

学習率の設定

finetuningは全層再学習になるので、層ごとグラデーションのある学習率をかけていきます。
スケジューラも併用するので、実際の学習率は以下の式で計算されます。
(スケジューラ係数:時間依存)×(層別倍率:固定)×(ベースLR)

まずは、param_group を作成します。

from timm.optim.optim_factory import param_groups_layer_decay
from timm.scheduler import CosineLRScheduler
import torch

# LLRDでパラメタグループ作成(層ごと倍率を自動付与)
groups = param_groups_layer_decay(
    model,
    weight_decay=1e-2, # 減衰
    layer_decay=0.75,  # 深い層ほど大きく動かす
    no_weight_decay_list=getattr(model, "no_weight_decay", lambda: set())(),
    # メソッドがあれば返す、なければ空集合を返す
)

no_weight_decay_list をモデルが持っていれば、自動的にno-decay(正則化をかけない)対象が設定されます。
weight_decay がかかるということはモデルの重みが不必要に大きくなるのを防いでくれますが、以下のパラメータにはこのような正則化を掛ける必要がありません。

  • bias
  • BatchNorm / LayerNorm の weight / bias
  • ViTの pos_embed, cls_token

※ timmモデルには、model.no_weight_decay()が実装されているので今はあまり気にする必要はありません。何が入るかはモデル側の実装次第です

もしtimmモデル以外を使用する場合はdecaynodecayを自分で作ります。

decay, nodecay = [], []
for n, p in model.named_parameters():
    if not p.requires_grad: continue
    if p.ndim == 1 or any(k in n.lower() for k in ["bias","bn","ln","norm"]) or n in ["pos_embed","cls_token"]:
        nodecay.append(p)
    else:
        decay.append(p)
optimizer = torch.optim.AdamW([
    {"params": decay,   "weight_decay": 1e-2},
    {"params": nodecay, "weight_decay": 0.0},
], lr=1e-4)

定義した param_groupoptimizer にわたします。

# Optimizer(各グループに倍率が掛かる)
optimizer = torch.optim.AdamW(
    groups,              # param_groups(LR倍率やWDが含まれる)
    lr=1e-4,             # ベースLR(各グループに倍率が掛かる)
    betas=(0.9, 0.999),  # 一次/二次モーメントの減衰率
    eps=1e-8             # 数値安定化の微小値
)

lr は基準となる学習率であり、実際の学習率は先ほど示したとおり(スケジューラ係数:時間依存)×(層別倍率:固定)×(ベースLR)となります。
betas はAdamの一階/二階モーメントの減衰率ですが、デフォルト通りです。

ではスケジューラを設定します。

# timmのCosineスケジューラ(warmup付き)
scheduler = CosineLRScheduler(
    optimizer,
    t_initial=num_epochs,     # コサインの周期(エポック数基準)
    lr_min=0.0,               # 最終的に近づける最小LR
    warmup_t=3,               # ウォームアップ長(エポック数)
    warmup_lr_init=1e-6,      # ウォームアップ開始LR(ベースLRより小さく)
    t_in_epochs=True,         # True=epoch単位, False=iteration単位
)

今回は、3エポックでウォームアップして、コサインで減衰していき最終的には0になるような設定です。

今回はtimmの param_groups_layer_decayCosineLRScheduler を使っています。

結果

num_epoch=10で学習した結果が以下です。

Best Epoch: 7
Best Val Loss: 0.3733849522279091
Best Acc: 0.7369269728660583

image.png

学習率の推移を見てもきちんと動いていることが確認できました。
image.png

転移学習

転移学習は最後の分類層のみ学習する方法です。

重みを固定する関数の定義

以下の関数で重みと固定するパラメータと更新可能にするパラメータを指定します。

def freeze(module):
    """重みを固定する"""
    for p in module.parameters():
        p.requires_grad = False

def unfreeze(module):
    """重みを更新可能にする"""
    for p in module.parameters():
        p.requires_grad = True # デフォルトはTrue

最後の分類層のみ学習

freeze(model)
unfreeze(model.head)

線形層だけなので以下のように再設定しました。optimizerの設定が変わったのでschedulerも再設定します。

optimizer = torch.optim.AdamW([
    {"params": model.head.parameters(), "lr": 1e-4, "weight_decay": 0.0}
], betas=(0.9, 0.999), eps=1e-8)

scheduler = CosineLRScheduler(
    optimizer, t_initial=num_epochs,
    lr_min=0.0, warmup_t=5, warmup_lr_init=1e-6, t_in_epochs=True
)

結果

Best Epoch: 9
Best Val Loss: 0.8637985422543026
Best Acc: 0.1871490627527237

image.png

転移学習+最終block学習

分類層に加えて、最終blockのみ学習するような設定にしてみます。
先ほど、分類層はunfreezeしたので加えてblock[-1]もunfreezeします。

unfreeze(model.blocks[-1])

学習率の設定は以下のとおり行いました

optimizer = torch.optim.AdamW([
    {"params": model.blocks[-1].parameters(), "lr": 1e-5, "weight_decay": 1e-2},
    {"params": model.head.parameters(),       "lr": 1e-4, "weight_decay": 0.0},  # head大きめ&no-decay
], betas=(0.9, 0.999), eps=1e-8)
scheduler = CosineLRScheduler(
    optimizer, t_initial=num_epochs,
    lr_min=0.0, warmup_t=3, warmup_lr_init=1e-6, t_in_epochs=True
)

結果

Best Epoch: 8
Best Val Loss: 0.7190584313924996
Best Acc: 0.3278992474079132

image.png

4. モデルの形状を変えて学習

事前学習済みのモデルの一部の形状や構造を変えたいというニーズがある状況です。

パラメータ名が異なるモデルに重みをそのままロードしてもエラーになってしまいます。
まずは新しいモデルと既存の重みがどのように異なるのかを調べる必要があります。

重みをロード

ckpt = torch.load("best_loss_model.pth", map_location="cpu")

今回は、新しい別のモデルに学習済みのモデルのpthを読み込ませるという想定です。
今まで使っていたモデルと少しだけ異なるようなモデルを適当に選んでみました。

num_classes = 7
model = build_model(model_name="vit_base_patch16_clip_224.openai", pretrained=False, num_classes=num_classes).to(device)

load_state_dictをする際に strict=False とすると、パラメータのキーが一致するものだけロードすることになります。

def model_load(model, ckpt, strict=False):
    msg = model.load_state_dict(ckpt, strict=strict)
    print("=== load report ===")
    print("missing_keys   :", getattr(msg, "missing_keys", []))
    print("unexpected_keys:", getattr(msg, "unexpected_keys", []))
    return msg
msg = model_load(new_model, ckpt, strict=False)

返り値のmsgには2つの属性があります。

  • missing_keys: (モデル側)一致しないモデルのパラメータを表示
  • unexpected_keys: (重み側)予期しないモデルのパラメータを表示

実際に行うと以下の出力となります。

msg = model_load(model, ckpt, strict=False)
=== load report ===
missing_keys   : ['norm_pre.weight', 'norm_pre.bias']
unexpected_keys: ['patch_embed.proj.bias']

なお、同名のパラメータですが形の違うテンソルが含まれると、strict=False でもサイズ不一致でエラーになってしまいます。
この場合、名前と形が一致するものだけfileterしてから load_state_dict(strict=False) するとOKです。

from collections import OrderedDict
def filter_by_name_and_shape(model, src_sd):
    """名前と形が一致するキーだけを抽出"""
    dst_sd = model.state_dict()
    filtered = OrderedDict()
    for k, v in src_sd.items():
        if (k in dst_sd) and (v.shape == dst_sd[k].shape):
            filtered[k] = v
    return filtered

filt = filter_by_name_and_shape(model, ckpt)
model.load_state_dict(filt, strict=False)

今回は、finetuneで全層学習させると一番良い結果になりました。
以上です。読んでいただきありがとうございました。

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?