#はじめに
fairseqでpreprosess後に各トークンにどのIDが割り当てられているのか確認する方法です.
なおfairseqのバージョンは0.9.0で行っています.
#結論
preprocess後に初期設定ではdata-bin
にバイナリデータが保存されているはずです.
その中にdict.en.txt
などdictから始まるファイルが通常存在していますので,今回それを使用します.
以下がtoken2idを取得するコードです.
>>> import fairseq
>>> dict = fairseq.data.Dictionary.load("dict.en.txt")
>>> token2id = dict.indices
>>> {key:token2id[key] for key in list(token2id.keys())[:15]}
{'<s>': 0, '<pad>': 1, '</s>': 2, '<unk>': 3, ',': 4, '.': 5, 'the': 6, 'of': 7, 'to': 8, 'and': 9, 'a': 10, 'in': 11, 'that': 12, '“': 13, '”': 14}
全てのtoken2idを表示すると大変なので,とりあえず15個出力してみました.
うまくトークンとIDの情報が対応されているのがわかりますね.
#考察
ここではどのようにIDが振られているのか考えてみましょう.
ちなみにdict.en.txt
の中身は以下のようになっています.
head -n 10 dict.en.txt
, 824248
. 801146
the 691194
of 428707
to 425800
and 326036
a 259768
in 255893
that 189164
“ 173250
ここで,左側はトークン,右側はテキスト内の出現頻度を表しています.
そのためtail dict.en.txt
を実行すると右側が1などばかりで低頻度語がファイルの下の行に記述されていることがわかります.
token2id
と比べてみると,4番以降のIDを持つトークンと行数の順番が一致しています.
そのため,<s>,<pad>,</s>,<unk>
の4つが予約語として割り当てられていて,それ以降は出現頻度順で割り当てられていることがわかります.
このことからdict.en.txt
にはトークンとその出現頻度のみ記述すればよく,あとはソースコード側で頻度順でIDを振っていけば良いということになります.
ここまでの流れが
fairseq/data/dictionary.py
に長ったらしく書かれています.
コードを読みたかったら
https://github.com/pytorch/fairseq/blob/main/fairseq/data/dictionary.py
を参照してください.
以上のことを踏まえてtoken2idを実装すると以下のようになります.
ここでは簡略化のため,なにも例外処理などしていません.
defined_tag = ['<s>','<pad>','</s>','<unk>']
token2id = dict()
ids = 0
for token in defined_tag:
token2id[token] = ids
ids+=1
with open('dict.en.txt') as f:
for line in f:
token,_ = line.rstrip().split()
token2id[token] = ids
ids+=1
print({key:token2id[key] for key in list(token2id.keys())[:15]})
以上により簡単にtoken2idを自作することができますが,例外処理など見落とす場合があるので,fairseq.data.Dictionary.load()
を使うことをおすすめします.