ずっとinteractive.pyを改変してやってきましたが、公式にいい関数がありました。
以前の方法
fairseqのinteractiveをクラス化する
コード
from fairseq.models.transformer import TransformerModel
class Interactive:
def __init__(self, spm_path, data_path, checkpoint_path, checkpoint_name):
#同時処理文数
self.num = 32
self.ltos = TransformerModel.from_pretrained(
checkpoint_path,
checkpoint_file=checkpoint_name,
data_name_or_path=data_path,
bpe='sentencepiece',
sentencepiece_model=spm_path,
no_repeat_ngram_size=2
)
def inference(self, texts: list):
result = []
n = self.num
for t in [texts[i*n:(i+1)*n] for i in range(len(texts))]:
result += self.ltos.translate(t)
return [r.replace("_", " ") for r in result]
解説
- 基本的には
fairseq.models.transformer.from_pretrained.translate
で推論を行うだけです。 - 100文くらいを同時処理しようとするとメモリを食うせいか遅くなるので、複数文に分けて与えています。
- CPUは搭載数の半分しか使わないみたいです。
今までのコードの5倍くらい早くなりました。