LoginSignup
2
1

transformersのモデルのダウンロードについて

Last updated at Posted at 2024-02-21

目的

transformersのモデルをローカルにダウンロードする方法を示す。

動機

transformersで作業を続けているとき、アップデートが
入ると再ダウンロードが必要になる。Phi-2で約5GBの保存容量が必要になり
ダウンロードの度にそれなりの時間がかかり煩わしい。(光回線が引けない家🥺)

ダウンロードするためのコード

AutoTokenizerやAutoModelForCausalLMのsave_pretrained関数[1]のsave_directoryという引数に
保存先を指定するとモデルを指定場所に保存できる。
保存したあとはfrom_pretrainedのpretrained_model_name_or_path引数に保存した場所を指定すれば
モデルを読み込み使用できる。


from pathlib import Path
from transformers import AutoTokenizer, AutoModelForCausalLM

def get_tokenizer_and_model(pre_trained_model_name, save_dir)
    if not Path(save_dir).exists():
        # download model
        self.tokenizer = AutoTokenizer.from_pretrained(pre_trained_model_name)
        self.model = AutoModelForCausalLM.from_pretrained(pre_trained_model_name)
        # save model
        self.tokenizer.save_pretrained(save_dir)
        self.model.save_pretrained(save_dir)

    else:
        # load model
        tokenizer = AutoTokenizer.from_pretrained(save_dir)
        model = AutoModelForCausalLM.from_pretrained(save_dir)

"the attention mask and the .."という警告

しかし、ダウンロードしたモデルで推論させようとすると次のような警告が表示される。

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.

これはモデルにパディングトークンが渡されていないために発生する。
通常はパディングトークン=EOSトークンとして扱ってしまって問題ないので
次のコードのようにpad_token_idにtokenizer.eos_token_idを設定する。[2]

model.generate(**encoded_input, pad_token_id=tokenizer.eos_token_id)

そうすることで上記の警告は発生しないように。

まとめ

時々入る再ダウンロードに煩わされることがなく快適に!!

参考サイト

save_pretrained関数について

"the attention mask and the .."という警告に関する質問

2
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
1