2
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

Huggingfaceの事前学習済みモデルに層を追加し,追加した層のみを学習対象とする

Last updated at Posted at 2023-12-19

概要

以下を実現する方法の一つを紹介

  • huggingfaceのモデルにAdapterやLoRA,クラス分類用のHeadといった層を追加する
    • from_pretrained()でロード出来た重みはrequires_grad=False(勾配計算をしない)
    • ロード出来なかった重み(自作の層)はrequires_grad=True(勾配計算をする)

実装

事前学習済みモデルを取得する

Huggingfaceでは公開されているモデルを簡単に入手できます.
今回は例としてCLIPVisionModelを用います(ドキュメント).

from transformers import CLIPVisionModel, CLIPVisionConfig

model_name = "openai/clip-vit-base-patch32"
config = CLIPVisionConfig.from_pretrained(model_name)
model = CLIPVisionModel.from_pretrained(
    model_name,
    config=config,
)

自作モデルを作り事前学習済みの重みをロードする

自作モデルの例として,CLIPVisionModelにクラス分類用の層を加えます.
なおCLIPでクラス分類を行いたい場合,HuggingfaceにはCLIPVisionModelWithProjectionが用意されているため,わざわざこのようにする必要はありませんが,説明のために作成します.

import torch
import torch.nn as nn

class MyCLIPVisionModel(CLIPVisionModel):
    def __init__(self, config: CLIPVisionConfig, num_classes: int) -> None:
        super().__init__(config)
        self.head = nn.Linear(config.hidden_size, num_classes)

    def forward(
        self,
        pixel_values: torch.Tensor,
        output_attentions: bool,
        output_hidden_states: bool,
        return_dict: bool
    ) -> torch.Tensor:
        vision_outputs = super().forward(
            pixel_values,
            output_attentions,
            output_hidden_states,
            return_dict
        )
        image_embeds = vision_outputs.pooler_output
        output = self.head(image_embeds)
        return output

自作のMyCLIPVisionModelPreTrainedModelを継承しているため,from_pretrained()を使って事前学習済みの重みを取得することが出来ます.

model_name = "openai/clip-vit-base-patch32"
config = CLIPVisionConfig.from_pretrained(model_name)
num_classes = 10
model = MyCLIPVisionModel.from_pretrained(
    model_name,
    config=config,
    num_classes=num_classes
)

追加した層のみ学習対象とする

先ほどの自作クラスからfrom_pretrained()を使って事前学習済みの重みをロードした際に,以下2つの警告が出ました.(1つ目の警告は長いので省略して載せています)

  • 警告1つ目
Some weights of the model checkpoint at openai/clip-vit-base-patch32 were not used when initializing MyCLIPVisionModel: ['text_model.encoder.layers.7.mlp.fc1.weight', ..., 'text_model.encoder.layers.11.self_attn.out_proj.weight']

CLIPというモデルはテキストエンコーダと画像エンコーダが存在します.
この警告は,「"openai/clip-vit-base-patch32"のcheckpointに存在する事前学習の重みを持つ層が,自作モデルには存在していない」ことを意味しています.
つまり,テキストエンコーダのcheckpointが保存されているのに使われていないという警告なので,画像エンコーダのみを使いたい場合は無視して問題ありません.

  • 警告2つ目
Some weights of MyCLIPVisionModel were not initialized from the model checkpoint at openai/clip-vit-base-patch32 and are newly initialized: ['head.weight', 'head.bias']

ここに書かれた['head.weight', 'head.bias']の層は事前学習済みモデルには存在しないため,「事前学習済みモデルで初期化できない,新たに初期化した」と言ってきます(当然).

この警告から,['head.weight', 'head.bias']以外の層,すなわち元のCLIPVisionModelの層は事前学習の重みを取得出来ていることがわかります.
しかし,このまま学習すると,全ての層が学習対象となります.

# 全ての層が学習対象となっていることを確認
array = []
for param in model.parameters():
    array.append(param.requires_grad)
print(all(array)) # True

そこで追加した層のみrequires_grad=Trueとするため,警告で出てくる['head.weight', 'head.bias']を取得します

from_pretrained()のドキュメントを見ると,output_loading_infoという引数が用意されています.
これを使うと警告で出てくるパラメータ名を取得出来ます.

model, loading_info = MyCLIPVisionModel.from_pretrained(
    model_name,
    config=config,
    num_classes=num_classes,
    output_loading_info=True,
)
# 警告1つ目のパラメータ名
print(loading_info['unexpected_keys']) # ['text_model.encoder.layers.7.mlp.fc1.weight', ..., 'text_model.encoder.layers.11.self_attn.out_proj.weight']
# 警告2つ目のパラメータ名
print(loading_info['missing_keys']) # ['head.weight', 'head.bias']

最終的には以下のようにして今回やりたかったことを実現出来ます.

model_name = "openai/clip-vit-base-patch32"
config = CLIPVisionConfig.from_pretrained(model_name)
num_classes = 10
model, loading_info = MyCLIPVisionModel.from_pretrained(
    model_name,
    config=config,
    num_classes=num_classes,
    output_loading_info=True,
)
for name, param in model.named_parameters():
    if name in loading_info['missing_keys']:
        param.requires_grad = True
    else:
        param.requires_grad = False

参照

2
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?