26
17

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

はじめに

faiseqを使ってFusion-in-decoderを実装しようとしたら予想以上に苦労したのでfairseqにあるクラスファイルについて細かく説明する.
dataloaderはどこにあるのか, そもそものデータの読み込みはどこで行っているのか, モデルはどこで定義されているのか, などコードリーディングの負荷を下げるためのメモ程度に書いていく.

目次

faieseqとは

翻訳, 要約, 言語モデリング, その他のテキスト生成タスクのためにモデルをトレーニングすることができる自然言語処理向けのツールキットのこと.
pytorchベースで作成されているので一部分のみを使うこと, モデルやタスクを自分で定義してトレーニングや評価することが可能
fairseq用のクラスを定義して用いてるので細かい部分で融通が利かないのが難点.

簡単な使い方

インストールからモデルの学習, テストによる評価までを簡単に書いていく.
細かい説明はドキュメントがあるので省略する.

https://fairseq.readthedocs.io/en/latest/#
##1. fairseqのインストール
gitからcloneすることでインストールを行う.
pipのみでインストールすることもできるが, gitからcloneすることを推奨.

git clone https://github.com/pytorch/fairseq
cd fairseq 
pip install --editable ./

上記コマンドでは最新版がインストールされることに注意

インストールが正常に終了すればコマンドプロンプトでfairseqのコマンドが使用できる
command not foundが出たらインストールできていないのでインストールからやり直す必要がある.

C:User> fairseq-train

##2. データセットの作成
fairseqで処理できるにデータセットを作成する必要がある.
データセットはsrcとtgtの二つに分けてtrain.src, train.tgtのように作成する.

train.src
モデルへの入力文A
モデルへの入力文B
モデルへの入力文C
train.tgt
モデルの出力A
モデルの出力B
モデルの出力C

train.srcのN行目とtrain.tgtのN行目が一つのペアとなる

preprocessは最低限以下のコマンドのように指定すればいい.

  • pref : 各ファイルのpath
  • destdir : 作成したデータセットの保存先
  • task : 翻訳, 分類,などモデルに学習させるタスクの種類
fairseq-preprocess 
   --trainpref dataset/raw/train \
   --validpref dataset/raw/valid \
   --testpref dataset/raw/test \
   --source-lang src --target-lang tgt \
   --destdir dataset/tokenized \
   --srcdict dataset/dict.src.txt \
   --tgtdict dataset/dict.dst.txt \
   --task translation 

##3. モデルのトレーニング
今回は1つのGPUを使って学習する場合について記載する.
teeやリダイレクトを使ってログを取っておくことを推奨.

  • arch : 使用するモデルのアーキテクチャを指定
  • restore-file : 事前学習済みのものをfinetuneするならファイルを指定
  • lr : learning rateの初期値, default=0.25となってるから指定することを推奨
  • max-epoch : epoch数の指定, 指定しないと永遠に学習を続けるので適当に値を入れておくといい
fairseq-train dataset/tokenized/ \
  --task translation \
  --arch transformer \
  --restore-file model/model.pt \
  --reset-optimizer --reset-dataloader --reset-lr-scheduler \
  --source-lang src --target-lang dst \
  --max-epoch 20 \
  --lr 1e-4 \
| tee train.log

##4. モデルの評価
評価についても1つのGPUを用いて行う.
モデルのトレーニングと同様にtee, リダイレクトを用いることを推奨.

  • path : 評価するモデルの相対パス
  • gen-subset : 使用するデータセットの指定(ex. train, valid, test)
fairseq-generate dataset/tokenized
  --path checkpoint/checkpoint_best.pt \
  --task translation \
  --gen-subset test \
  -s src -t tgt \ 
> result.txt

各種ファイルの役割

fairseqはpytorchをベースにしているため拡張性は高いが, 独自のクラスを定義するため一部のみを利用したり, 細かい部分で書き換えるときに参照すべきファイルが多い.
ここではtasks, models, criterions, dataなど各フォルダの中でメインとなる部分について説明していく.

##1. ディレクトリ構造について(一部省略)
基本的には下記のファイルとディレクトリについて理解しておけば十分.
特にtasksは各クラスとの関係性が強いのでtasksを中心に各ファイルについて詳細を説明する.
fairseq_cliはfairseqディレクトリの処理内容が分かれば理解しやすいので省略する

fairseq/
 ├ fairseq
   ├ criterions #lossの算出方法などのファイルがあるディレクトリ
   ├ data   #dataloader, ファイルからデータの読み込みなどを行う
   ├ models  #モデルの定義ファイルのあるディレクトリ
   └ tasks   #タスク(翻訳, 分類, etc..)の定義のあるディレクトリ
 └ fairseq_cli
   ├ preprocess.py #データセット作成時に実行したプログラム
   ├ train.py   #モデルのトレーニング時に実行したプログラム
   └ generate.py #モデルの評価時に実行したプログラム

##2. tasks
taskクラスを基にデータセットの読み込み, モデルの読み込みなど行う.
fairseq_task.pyにあるFairseqTaskクラスがtaskの基準となるため, FairseqTaskクラスを継承すればtaskを個人で定義することが可能.
load_dataset関数内でload_langpair_dataset関数が呼び出されることでデータの読み込みを行うのでここは詳しく確認する必要はない

def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.
        Args:
            split (str): name of the split (e.g., train, valid, test)
        """

taskを個人で定義する場合は以下のリンクを参照すればいい.
今回はdataloader, datasetの読み込み時についてなので省略する.

##3. data
language_pair_dataset.py, indexed_dataset.py, prepend_token_dataset.pyの3つについて定義されている関数や, 各ファイルの役割について説明する
###language_pair_dataset.py
indexed_dataset.pyで一つ一つのファイルから読み込みを行うが実際にデータセットとして管理するのはlanguage_pair_datasetで行われる.
train.srcとtrain.tgtの2つのファイルを繋げるための役割を担っている.

class LanguagePairDataset(FairseqDataset):
   def __init__(...):
    ...省略...

データの取り出しは__getitem()__を使っている. しかし, 実際にモデルの入力として渡されるのはここの返却値である__example__ではないこと注意が必要

class LanguagePairDataset(FairseqDataset):
  def __getitem__(self, index):
     tgt_item = self.tgt[index] if self.tgt is not None else None
     src_item = self.src[index]

       ...省略...

     example = {
            "id": index,
            "source": src_item,
            "target": tgt_item,
           }
     retrun example

dataloaderを定義する際にlanguage_pair_dataset.pyにある__collate()__が呼ばれるように設定される. 上記の__example__を受け取り, モデルへの入力にするために変換する関数として作用する. 具体的には__batch__にある__net_input__がモデルへの入力として用いられる.

def collate(...):
 
  ...省略...  

  batch = {
        "id": id,
        "nsentences": len(samples),
        "ntokens": ntokens,
        "net_input": {
            "src_tokens": src_tokens,
            "src_lengths": src_lengths,
        },
        "target": target,
    }

  ...省略...

  return batch

###indexed_dataset.py
ここでfairseq-preprocess時に作成したファイルを参照し, datasetとして読み込む
dataset-implによって作成されるデータセットの型が変わるが, fairseq-preprocess時にdataset-implを指定しない場合はIndexedRawTextDatasetになるため基本的にはここを見ればいい.

class IndexedRawTextDataset(FairseqDataset):
    def read_data(self, path, dictionary):
        ...省略...

ここでファイルからデータを読み込んでいる, train.srcとtrain.tgtのどちらもこの関数を用いるのでプログラムを書き換える際には注意が必要.

###prepend_token_dataset.py
データセットからデータが読み込まれる際はこのファイルにある__getitem()__などが呼ばれる.
データセットの形を変えた場合この部分も同様に変更する必要がある.
idxはデータセットから取り出すデータのidでありint型.

def __getitem__(self, idx):
        item = self.dataset[idx]
        if self.token is not None:
            item = torch.cat([item.new([self.token]), item])
        return 

##4. criterions
lossの算出方法が定義されているフォルダ, defaultではcross_entropy.pyを呼び出される.
modelから__foward()__を呼び出すのではなく, criterionsを通してmodelの__foward()__が呼ばれるので処理順で確認するならばcriterionsを見ることを推奨する.
__sample__はlanguage_pair_datasetの__collate()__で変換された__batch__に対応している.

class CrossEntropyCriterion(FairseqCriterion):
    def __init__(self, task, sentence_avg):
        super().__init__(task)
        self.sentence_avg = sentence_avg

    def forward(self, model, sample, reduce=True):
        net_output = model(**sample["net_input"]) 

        ...省略...

##5. model
modelにはfairseq_encoder.pyやtransformer_lm.pyなどモデルを定義する際に用いられるモジュールと, BART, RoBERTa, Wav2Vecなどの定義済みモデルがある. 定義済みモデルはフォルダで分けられているので, 同様にモデルを作成することが可能.

class TransformerModelBase(FairseqEncoderDecoderModel):
    """
    Transformer model from `"Attention Is All You Need" (Vaswani, et al, 2017)
    <https://arxiv.org/abs/1706.03762>`_.
    """

    def forward(
        self,
        src_tokens,
        src_lengths,
        prev_output_tokens,
        return_all_hiddens: bool = True,
        features_only: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
    ):

       encoder_out = self.encoder(
          src_tokens,src_lengths=src_lengths,return_all_hiddens=return_all_hiddens)
       decoder_out = self.decoder(
            prev_output_tokens,
            encoder_out=encoder_out,
            features_only=features_only,
            alignment_layer=alignment_layer,
            alignment_heads=alignment_heads,
            src_lengths=src_lengths,
            return_all_hiddens=return_all_hiddens,
           )
       return decoder_out

各ファイルにある__foward()__を確認すれば処理の内容が分かる. 上記はtrnasformerなのでencoderとdecoderから構成されており, encoderの出力とdecoderの入力をつなげている.
関数にある引数はcritrionsで呼んでいる__sample['net_input']__を受け取るようなっているので, 実質は__collate()__で返却されるbatchのnet_inputと同じ値となる.

おわりに

基本的にはデータの読み込みからモデルの出力を流れを追う際に見るべきファイルについて説明したが, 細かい処理の内容については複雑になるので省略したので悪しからず.
Fusion-in-Decoderの実装手順については別記事?で書くかもしれない.
fairseq_trainのコマンドを追うように説明したので処理を理解する助けになることを祈ってます.

26
17
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
26
17

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?