LoginSignup
13
4

More than 3 years have passed since last update.

fairseqで自分でトレーニングしたTransformerモデルをロードする

Posted at

はじめに

機械翻訳のTransformerモデルをトレーニングする機会があり,Pytorchベースのfairseqを使ったんですが,アプリ用のコードでモデルのロードにハマってしまいました.備忘録のために記事を書きます.

fairseqとは

Facebookの人工知能研究チームが開発している,機械翻訳用のフレームワークです.Facebookが開発元ということもあり,Pytorchがベースになっています.最近はHuggingfaceのTransformersが人気でTransformerモデルを扱うならPytorchだよね,ということもあり,こちらをフレームワークとして選びました.
その他の機械翻訳フレームワークとしては,MarianNMT,OpenNMT(こちらもPytorchベース)などがあります.基本的な機能はどのフレームワークも大差ない印象ですが,論文実装のコードはfairseqが選ばれている場合が多いです.

モデルのトレーニング方法

fairseqはドキュメントが充実しているため,チュートリアル通りにやれば,自分でも機械翻訳モデルをトレーニングすることができます.チュートリアルは英語のみですが,日本語のドキュメントが不足しているのは,他の機械翻訳フレームワークでも同様です.

トレーニングのためのコード(fairseqの公式ドキュメントから)
> mkdir -p checkpoints/fconv
> CUDA_VISIBLE_DEVICES=0 fairseq-train data-bin/iwslt14.tokenized.de-en \
    --lr 0.25 --clip-norm 0.1 --dropout 0.2 --max-tokens 4000 \
    --arch fconv_iwslt_de_en --save-dir checkpoints/fconv

トレーニングコードのsave-dirは学習済みモデルの格納フォルダですが,このフォルダには,各エポックにおける学習済みモデル(checkpoint(n).pt),精度が一番よかった学習済みモデル(checkpoint_best.pt),最終エポックにおける学習済みモデル(checkpoint_last.pt)が格納されます.

checkpoints_fconv
$ ls -alF
checkpoint1.pt
checkpoint2.pt
...
checkpoint99.pt
checkpoint_bes.pt
checkpoint_last.pt

推論の際は,checkpoint_best.ptを読み込むことになるかと思います.

インタラクティブモードでの推論

fairseqの公式ドキュメントには,シェル上のインタラクティブモードでの推論についてはチュートリアルがありますが,Pythonのコード内でモデルをロードする方法などは紹介されていません.

fairseq-interactiveによ推論(fairseqの公式ドキュメントから)
> MODEL_DIR=wmt14.en-fr.fconv-py
> fairseq-interactive \
    --path $MODEL_DIR/model.pt $MODEL_DIR \
    --beam 5 --source-lang en --target-lang fr \
    --tokenizer moses \
    --bpe subword_nmt --bpe-codes $MODEL_DIR/bpecodes
| loading model(s) from wmt14.en-fr.fconv-py/model.pt
| [en] dictionary: 44206 types
| [fr] dictionary: 44463 types
| Type the input sentence and press return:
Why is it rare to discover new marine mammal species?
S-0     Why is it rare to discover new marine mam@@ mal species ?
H-0     -0.0643349438905716     Pourquoi est-il rare de découvrir de nouvelles espèces de mammifères marins?
P-0     -0.0763 -0.1849 -0.0956 -0.0946 -0.0735 -0.1150 -0.1301 -0.0042 -0.0321 -0.0171 -0.0052 -0.0062 -0.0015

Pythonのコード内で学習済みモデルをロードする方法

modelsクラスを使うのが正解でした.インタラクティブモードの場合と異なり,いくつかの引数は自動でロードされるようです.

>>> from fairseq.models.transformer import TransformerModel
>>> model = TransformerModel.from_pretrained('wmt14.en-fr.fconv-py', 'model.pt', 'wmt14.en-fr.fconv-py')
>>> text = 'Why is it rare to discover new marine mam@@ mal species ?'
>>> model.translate(text, beam=5)
'Pourquoi est @-@ il rare de découvrir de nouvelles espèces de mammifères marins ?'

まとめ

fairseqは日本語ドキュメントが少ないですが,HuggingfaceのTransformersを使っている人であれば使いやすいですし,Pytorchベースということにも将来性を感じます.
論文実装のコードも多いですし,今後も触っていきたいフレームワークですね.

13
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
13
4