3
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

transformersのAutoTokenizerのuse_fastのデフォルト値がバージョンによって異なっていた

Last updated at Posted at 2022-07-01

要約

transformersライブラリのバージョンを上げるときは、AutoTokenizer.from_pretrained で use_fastを指定しましょうという話


Huggingfaceのtransformersライブラリを使っていると、古いバージョンから新しいバージョンに切り替えたい時があると思います。

自分は3.5.0から4.18.0に切り替えようとした時、実際は単純だったのに謎にハマったエラーがあったので備忘録も兼ねて共有します。

問題があったのは以下のコードです。

>>> from transformers import AutoTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("roberta-base")

これを実行して得られるtokenizerが、transformersのバージョンによって異なるという話です。

具体的には、以下のようになります。

3.5.0
>>> tokenizer
PreTrainedTokenizer(name_or_path='roberta-base', vocab_size=50265, model_max_len=512, is_fast=False, padding_side='right', special_tokens={'bos_token': AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'eos_token': AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'unk_token': AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'sep_token': AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'pad_token': AddedToken("<pad>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'cls_token': AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'mask_token': AddedToken("<mask>", rstrip=False, lstrip=True, single_word=False, normalized=True)})
4.18.0
>>> tokenizer
PreTrainedTokenizerFast(name_or_path='roberta-base', vocab_size=50265, model_max_len=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'sep_token': '</s>', 'pad_token': '<pad>', 'cls_token': '<s>', 'mask_token': AddedToken("<mask>", rstrip=False, lstrip=True, single_word=False, normalized=False)})

このように、得られるtokenizerのクラス自体がFastかそうでないかという違いがありました。

怖いですね。

自分の場合はこの違いによって、文を入れたときのtokenizerの出力が両者で異なることがバージョンをあげた時に起こるエラーの原因になっていました。

バージョンによって異なる理由は、バージョンによってAutoTokenizerのuse_fastのデフォルト値が異なるためでした。(TrueかFalseか)

解決策は、明示的にuse_fastを指定するだけです。

>>> from transformers import AutoTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("roberta-base", use_fast=False)

これで一安心です。

参考リンク

3
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?