はじめに
機械翻訳のTransformerモデルをトレーニングする機会があり,Pytorchベースのfairseqを使ったんですが,アプリ用のコードでモデルのロードにハマってしまいました.備忘録のために記事を書きます.
fairseqとは
Facebookの人工知能研究チームが開発している,機械翻訳用のフレームワークです.Facebookが開発元ということもあり,Pytorchがベースになっています.最近はHuggingfaceのTransformersが人気でTransformerモデルを扱うならPytorchだよね,ということもあり,こちらをフレームワークとして選びました.
その他の機械翻訳フレームワークとしては,MarianNMT,OpenNMT(こちらもPytorchベース)などがあります.基本的な機能はどのフレームワークも大差ない印象ですが,論文実装のコードはfairseqが選ばれている場合が多いです.
モデルのトレーニング方法
fairseqはドキュメントが充実しているため,チュートリアル通りにやれば,自分でも機械翻訳モデルをトレーニングすることができます.チュートリアルは英語のみですが,日本語のドキュメントが不足しているのは,他の機械翻訳フレームワークでも同様です.
> mkdir -p checkpoints/fconv
> CUDA_VISIBLE_DEVICES=0 fairseq-train data-bin/iwslt14.tokenized.de-en \
--lr 0.25 --clip-norm 0.1 --dropout 0.2 --max-tokens 4000 \
--arch fconv_iwslt_de_en --save-dir checkpoints/fconv
トレーニングコードのsave-dir
は学習済みモデルの格納フォルダですが,このフォルダには,各エポックにおける学習済みモデル(checkpoint(n).pt
),精度が一番よかった学習済みモデル(checkpoint_best.pt
),最終エポックにおける学習済みモデル(checkpoint_last.pt
)が格納されます.
$ ls -alF
checkpoint1.pt
checkpoint2.pt
...
checkpoint99.pt
checkpoint_bes.pt
checkpoint_last.pt
推論の際は,checkpoint_best.pt
を読み込むことになるかと思います.
インタラクティブモードでの推論
fairseqの公式ドキュメントには,シェル上のインタラクティブモードでの推論についてはチュートリアルがありますが,Pythonのコード内でモデルをロードする方法などは紹介されていません.
> MODEL_DIR=wmt14.en-fr.fconv-py
> fairseq-interactive \
--path $MODEL_DIR/model.pt $MODEL_DIR \
--beam 5 --source-lang en --target-lang fr \
--tokenizer moses \
--bpe subword_nmt --bpe-codes $MODEL_DIR/bpecodes
| loading model(s) from wmt14.en-fr.fconv-py/model.pt
| [en] dictionary: 44206 types
| [fr] dictionary: 44463 types
| Type the input sentence and press return:
Why is it rare to discover new marine mammal species?
S-0 Why is it rare to discover new marine mam@@ mal species ?
H-0 -0.0643349438905716 Pourquoi est-il rare de découvrir de nouvelles espèces de mammifères marins?
P-0 -0.0763 -0.1849 -0.0956 -0.0946 -0.0735 -0.1150 -0.1301 -0.0042 -0.0321 -0.0171 -0.0052 -0.0062 -0.0015
Pythonのコード内で学習済みモデルをロードする方法
modelsクラスを使うのが正解でした.インタラクティブモードの場合と異なり,いくつかの引数は自動でロードされるようです.
>>> from fairseq.models.transformer import TransformerModel
>>> model = TransformerModel.from_pretrained('wmt14.en-fr.fconv-py', 'model.pt', 'wmt14.en-fr.fconv-py')
>>> text = 'Why is it rare to discover new marine mam@@ mal species ?'
>>> model.translate(text, beam=5)
'Pourquoi est @-@ il rare de découvrir de nouvelles espèces de mammifères marins ?'
まとめ
fairseqは日本語ドキュメントが少ないですが,HuggingfaceのTransformersを使っている人であれば使いやすいですし,Pytorchベースということにも将来性を感じます.
論文実装のコードも多いですし,今後も触っていきたいフレームワークですね.