Why not login to Qiita and try out its useful features?

We'll deliver articles that match you.

You can read useful information later.

1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

HuggingFaceのBARTで入力シーケンス長を変更する

Posted at

HuggingFaceで公開されているBARTはデフォルトで1024トークンが入力の最大長
要約のため、Webページまるまる1つとかとんでもない長文を入れようとすると、これでは足りなくなってくる
今回は入力シーケンス長を変更する手順をまとめ、パラメータコピーを行う関数を作る

やってることはレイヤー名が同じで一部のシェイプが異なるPytorchモデルのパラメータコピーである

注意点

入力シーケンス長を変えただけでは、学習されてないパラメータが残ってしまう
そのため、実際に使う前に必ずfine-tuningを行う必要がある

configの変更

huggingfaceのドキュメントを読むと、max_position_embeddingsがモデルの入力シーケンス長を決定していることがわかる
https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartConfig.max_position_embeddings

これを変更することで、モデルの入力長は変更ができる
MBartConfigに与えているjsonファイルはHuggingFaceレポジトリページのFiles and versionsから持ってこれる

from transformers import MBartForConditionalGeneration, MBartConfig
config = MBartConfig.from_json_file("./config.json")
config.max_position_embeddings = 1024*8
extend_model = MBartForConditionalGeneration(config)

パラメータのコピー

本記事一番のキモ
from_pretrainedは引数でmax_position_embeddingsを受け付けられるため、ここを変えれば楽勝やん!でやろうとすると、見事エラーの餌食になる

>>> model = AutoModelForSeq2SeqLM.from_pretrained("Formzu/bart-large-japanese", max_position_embeddings = 1024*8)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python3.8/dist-packages/transformers/models/auto/auto_factory.py", line 463, in from_pretrained
    return model_class.from_pretrained(
  File "/usr/local/lib/python3.8/dist-packages/transformers/modeling_utils.py", line 2379, in from_pretrained
    ) = cls._load_pretrained_model(
  File "/usr/local/lib/python3.8/dist-packages/transformers/modeling_utils.py", line 2695, in _load_pretrained_model
    raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
RuntimeError: Error(s) in loading state_dict for MBartForConditionalGeneration:
        size mismatch for model.encoder.embed_positions.weight: copying a param with shape torch.Size([1026, 1024]) from checkpoint, the shape in current model is torch.Size([8194, 1024]).
        size mismatch for model.decoder.embed_positions.weight: copying a param with shape torch.Size([1026, 1024]) from checkpoint, the shape in current model is torch.Size([8194, 1024]).
        You may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method.

まぁ形状変えてるし当然ではある
ここで仮にignore_mismatched_sizes=Trueを入れてしまうと言語モデルとして重要なパラメータがコピーされず、ほぼ初期状態の出力が出てきてしまう
ここでゼロから言語モデルを学習させるわけにもいかないので、ゴリゴリと泥臭くパラメータをコピーすることにする

これは先駆者がおり、特定レイヤーのパラメータコピーの方法を書いている
https://stackoverflow.com/questions/62603089/config-change-for-a-pre-trained-transformer-model

今回はこれを改変し、”レイヤー内容は同一でパラメータのシェイプが異なる” モデルのパラメータコピーを行えるようにする

そうしてできた関数が以下

def weight_copy(source_model, target_model, max_pos_size):
    tensor_name_list = list(target_model.state_dict().keys())
    for model_tensor_name in tensor_name_list:
        if target_model.state_dict()[model_tensor_name].shape == source_model.state_dict()[model_tensor_name].shape:
            target_model.state_dict()[model_tensor_name][:] = source_model.state_dict()[model_tensor_name]
        else:
            target_model.state_dict()[model_tensor_name][:max_pos_size] = source_model.state_dict()[model_tensor_name][:max_pos_size]

主にこの記事 https://qiita.com/mathlive/items/d9f31f8538e20a102e14 を参考にした
特に代入の際に参照要素を指定する必要があり、[:]とわざわざ書かないとパラメータがコピーされなかったりする
シェイプが異なる際に[:max_pos_size]としているのは、シーケンシャルだし左詰めでいいだろ……という安易な予想である
なお、モデルは参照渡しされるため、returnでモデルを返す必要はない

コピー後

あとはfine-tuningすれば思い通りのモデルが手に入る……はず
やり終わって気が向いたらまとめます

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?