4
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

【備忘録】Sentencepieceの学習時にオーバーフローを防ぎつつspecial tokenを正しく設定する方法

Posted at

公式のレポジトリにもいくつか説明箇所はあるのですが記法などが統一されていなかったりして少し調べるのに戸惑ったので、こちらのブログに書き残しておくことにしました。

なお、公式のレポジトリは以下になります。

Sentencepiceの学習について

Pythonインターフェースでのコードですが、非常にシンプルです。ググってよく出てくるサンプルだと例えば以下のような感じだと思います。

import sentencepiece as spm

# モデルの学習
spm.SentencePieceTrainer.Train(
    input='corpus.txt',
    model_prefix='sentencepiece',
    vocab_size=32000,
    character_coverage=0.9995
)

# こんな感じに書くこともできる
spm.SentencePieceTrainer.Train('--input=corpus.txt --model_prefix=sentencepiece --vocab_size=32000 character_coverage=0.9995')

オーバーフローを防ぐ

(多分、、、)テキストファイルを一気にロードするため、大きなコーパスを学習させる際はCPUのコア数を増やす必要があります。その際以下の引数を追加することで、オーバーフローを防ぐことが可能です。

import sentencepiece as spm


spm.SentencePieceTrainer.Train(
    input='corpus.txt',
    model_prefix='sentencepiece',
    vocab_size=32000,
    character_coverage=0.9995,
    train_extremely_large_corpus=True.    # 追加
)

special tokenの設定

上のコードだと<unk>(Unknown), <s>(BOS), </s>(EOS)の3つのトークンは設定されているのですが、その他のspecial tokenは設定されていません。これだといざ学習されたtokenizerを使用とした際、何かと不都合が生じる可能性があります。以下のようにすることで、この問題は解消できました。

import sentencepiece as spm


spm.SentencePieceTrainer.Train(
    input='corpus.txt',
    model_prefix='sentencepiece',
    vocab_size=32000,
    pad_id=3,
    pad_piece='[PAD]',
    user_defined_symbols=['[PAD]', '[CLS]', '[SEP]', '[MASK]']
)

試しに生成されたtokenizerをtransformerのインターフェースでインスタンス化してみたところ以下のように正しく設定されていることが確認できます。

from transformers import DebertaV2Tokenizer


tokenizer = DebertaV2Tokenizer(
    vocab_file=sentencepiece.model,
    bos_token='<s>',
    eos_token='</s>',
    unk_token='<unk>',
    pad_token='[PAD]',
    cls_token='[CLS]',
    sep_token='[SEP]',
    mask_token='[MASK]',
    extra_ids=0,
    additional_special_tokens=(),
    do_lower_case=True
)

unk_id = tokenizer.unk_token_id
bos_id = tokenizer.bos_token_id
eos_id = tokenizer.eos_token_id
pad_id = tokenizer.pad_token_id
cls_id = tokenizer.cls_token_id
sep_id = tokenizer.sep_token_id
mask_id = tokenizer.mask_token_id
print(unk_id, tokenizer_myself3.convert_ids_to_tokens(unk_id))
print(bos_id, tokenizer_myself3.convert_ids_to_tokens(bos_id))
print(eos_id, tokenizer_myself3.convert_ids_to_tokens(eos_id))
print(pad_id, tokenizer_myself3.convert_ids_to_tokens(pad_id))
print(cls_id, tokenizer_myself3.convert_ids_to_tokens(cls_id))
print(sep_id, tokenizer_myself3.convert_ids_to_tokens(sep_id))
print(mask_id, tokenizer_myself3.convert_ids_to_tokens(mask_id))

結果

0 <unk>
1 <s>
2 </s>
3 [PAD]
4 [CLS]
5 [SEP]
6 [MASK]
4
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?