3
1

More than 3 years have passed since last update.

fairseqで推論するときのクラス

Posted at

ずっとinteractive.pyを改変してやってきましたが、公式にいい関数がありました。

以前の方法
fairseqのinteractiveをクラス化する

コード

from fairseq.models.transformer import TransformerModel

class Interactive:
    def __init__(self, spm_path, data_path, checkpoint_path, checkpoint_name):
        #同時処理文数
        self.num = 32

        self.ltos = TransformerModel.from_pretrained(
            checkpoint_path,
            checkpoint_file=checkpoint_name,
            data_name_or_path=data_path,
            bpe='sentencepiece',
            sentencepiece_model=spm_path,
            no_repeat_ngram_size=2
            )

    def inference(self, texts: list):
        result = []
        n = self.num
        for t in [texts[i*n:(i+1)*n] for i in range(len(texts))]:
            result += self.ltos.translate(t)
        return [r.replace("_", " ") for r in result]

解説

  • 基本的には fairseq.models.transformer.from_pretrained.translate で推論を行うだけです。
  • 100文くらいを同時処理しようとするとメモリを食うせいか遅くなるので、複数文に分けて与えています。
  • CPUは搭載数の半分しか使わないみたいです。

今までのコードの5倍くらい早くなりました。

3
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
1