pytorch_pretrained_bertを使っていますか?今はpytorch-transformersと名前を変えたこちらのライブラリですが、PyTorchでBERTを扱おうとしたときにはこれを利用している方は多いのではないでしょうか。
京大から事前学習モデルが配布されているため、日本語のデータでもタスクに応じた学習をすることが可能なのですが、ちょっとした落とし穴があったため記事にしておきます。ただし、今のバージョンだと修正されており、ちょっと古いバージョンを使っている人向けです
。私の方で問題が発生したときにはpytorch_pretrained_bertのバージョンは0.6.2でした。
問題発生
京大の事前学習モデルを使って自分のやりたいタスクの日本語モデルの学習後に、モデルの重みやvocab.txtなどの情報を保存しました。そして再度モデルを読み込み、予測をさせてみると・・・。
学習中に比べて全然うまくいってない!
ちょっと違うとかではないし、訓練データを予測させてもうまくいきません。
PyTorchのモデルの保存か読み込みでどこかミスしたのかと考えましたが、どうやら違うようです・・・。
原因と解決方法
原因はvocab.txtの読み込み
自分で学習した後に保存したvocab.txtを眺めていると、事前学習モデルのvocab.txtと行数が異なることに気づきます。次のような感じになります。
︙
など
3
−
この
ない
ため
︙
︙
など
3
−
この
ない
ため
︙
事前学習モデルのvocab.txtには46行目に空行、2451行目にnon-breaking spaceがあるのですが、実はvocab.txtをBERTのtokenizerで読み込むときにこれらを同一視してしまうような仕組みになっていました。同一視されることでvocab.txtの情報が1つ分なかったことにされてしまいます。学習後に保存したvocab.txtではこれが反映されて保存されたため、おかしくなっていたのでした。
以下のtokenization.pyのload_vocabにおけるstripの呼び出しとvocab[token]への代入が該当箇所になりますが、空行もnon-breaking spaceもstripすることで空文字になってしまっています。tokenに対応するindexを逐次代入していっていますので、vocab.txtの行数が異なるとこのindexに違いが生じてしまいます。
def load_vocab(vocab_file):
"""Loads a vocabulary file into a dictionary."""
vocab = collections.OrderedDict()
index = 0
with open(vocab_file, "r", encoding="utf-8") as reader:
while True:
token = reader.readline()
if not token:
break
token = token.strip()
vocab[token] = index
index += 1
return vocab
解決方法
上記コードのstripの呼び出しの行をtoken = token.rstrip("\n")とすることで、空行とnon-breaking spaceを別物として扱うことができるようになります。結果として事前学習のモデルでのvocab.txtと学習後に保存したvocab.txtは一緒になります。これは今のpytorch-transformersでは対応済みです。
この修正をおこなった後では、ちゃんと学習中の予測結果とモデルを読み込んだ場合の予測結果は一致しました!やったね