公式のレポジトリにもいくつか説明箇所はあるのですが記法などが統一されていなかったりして少し調べるのに戸惑ったので、こちらのブログに書き残しておくことにしました。
なお、公式のレポジトリは以下になります。
Sentencepiceの学習について
Pythonインターフェースでのコードですが、非常にシンプルです。ググってよく出てくるサンプルだと例えば以下のような感じだと思います。
import sentencepiece as spm
# モデルの学習
spm.SentencePieceTrainer.Train(
input='corpus.txt',
model_prefix='sentencepiece',
vocab_size=32000,
character_coverage=0.9995
)
# こんな感じに書くこともできる
spm.SentencePieceTrainer.Train('--input=corpus.txt --model_prefix=sentencepiece --vocab_size=32000 character_coverage=0.9995')
オーバーフローを防ぐ
(多分、、、)テキストファイルを一気にロードするため、大きなコーパスを学習させる際はCPUのコア数を増やす必要があります。その際以下の引数を追加することで、オーバーフローを防ぐことが可能です。
import sentencepiece as spm
spm.SentencePieceTrainer.Train(
input='corpus.txt',
model_prefix='sentencepiece',
vocab_size=32000,
character_coverage=0.9995,
train_extremely_large_corpus=True. # 追加
)
special tokenの設定
上のコードだと<unk>
(Unknown), <s>
(BOS), </s>
(EOS)の3つのトークンは設定されているのですが、その他のspecial tokenは設定されていません。これだといざ学習されたtokenizerを使用とした際、何かと不都合が生じる可能性があります。以下のようにすることで、この問題は解消できました。
import sentencepiece as spm
spm.SentencePieceTrainer.Train(
input='corpus.txt',
model_prefix='sentencepiece',
vocab_size=32000,
pad_id=3,
pad_piece='[PAD]',
user_defined_symbols=['[PAD]', '[CLS]', '[SEP]', '[MASK]']
)
試しに生成されたtokenizerをtransformerのインターフェースでインスタンス化してみたところ以下のように正しく設定されていることが確認できます。
from transformers import DebertaV2Tokenizer
tokenizer = DebertaV2Tokenizer(
vocab_file=sentencepiece.model,
bos_token='<s>',
eos_token='</s>',
unk_token='<unk>',
pad_token='[PAD]',
cls_token='[CLS]',
sep_token='[SEP]',
mask_token='[MASK]',
extra_ids=0,
additional_special_tokens=(),
do_lower_case=True
)
unk_id = tokenizer.unk_token_id
bos_id = tokenizer.bos_token_id
eos_id = tokenizer.eos_token_id
pad_id = tokenizer.pad_token_id
cls_id = tokenizer.cls_token_id
sep_id = tokenizer.sep_token_id
mask_id = tokenizer.mask_token_id
print(unk_id, tokenizer_myself3.convert_ids_to_tokens(unk_id))
print(bos_id, tokenizer_myself3.convert_ids_to_tokens(bos_id))
print(eos_id, tokenizer_myself3.convert_ids_to_tokens(eos_id))
print(pad_id, tokenizer_myself3.convert_ids_to_tokens(pad_id))
print(cls_id, tokenizer_myself3.convert_ids_to_tokens(cls_id))
print(sep_id, tokenizer_myself3.convert_ids_to_tokens(sep_id))
print(mask_id, tokenizer_myself3.convert_ids_to_tokens(mask_id))
結果
0 <unk>
1 <s>
2 </s>
3 [PAD]
4 [CLS]
5 [SEP]
6 [MASK]