36
22

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

huggingface/transformersのTrainerの使い方と挙動

Last updated at Posted at 2021-07-05

 Trainerは便利だが,中で何がどう動いているか分からないと怖くて使えないので,メモ。公式ドキュメントでの紹介はここ

基本的な使い方

from transformers import Trainer, TrainingArguments 

tokenizer=AutoTokenizer.from_pretrained('bert-base-uncased')
model=AutoModel.from_pretrained('bert-base-uncased')

args=TrainingArguments(
    output_dir='ディレクトリパス'
)

trainer=Trainer(
    args=args,
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset
)

trainer.train()
trainer.evaluate()

まず,TrainerとTrainingArgumentsをimportする。また,モデルとトークナイザとしてAutoModelAutoTokenizerを例示するが,モデルはBertModelでもBertForSequenceClassificationでも何でも良いし,トークナイザも同様である。

次に,TrainingArgumentsを作成する。必須なのはoutput_dirのみである。ここに学習の途中経過が保存される。

続いて,Trainerを作成する。評価をしないならばeval_datasetは不要。訓練データ・評価データともに,torch.utils.data.Datasetである必要がある(Datasetについてはここが一番分かりやすい)。

最後に,train()で学習し,evaluate()で評価する。評価はしなくても良い。なお,evaluate()の結果はコンソールには表示されず,TensorBoardで確認する羽目になる。標準出力で確認したければ,Trainercompute_metricsにCallableを設定して,そのCallable(関数など)で何か表示させるようにプログラムを書く必要がある。

BERTはエポック数,バッチサイズ,学習率の3つをチューニングする必要がある(BERTの元論文)が,以下の通り設定できる。

パラメータの説明

Trainer

Trainerの引数でよく使うのは以下。

パラメタ 引数 説明
args TrainingArguments 下記参照。
train_dataset Dataset 訓練用データ。バッチ処理は自動的に行われる。また,データはランダムに並べ替えられる。
eval_dataset Dataset 評価用データ。バッチ処理は自動的に行われる。データは並べ替えられない。
tokenizer PreTrainedTokenizerBase ここでtokenizerを指定しない場合,後からモデルを保存してもtokenizerは保存されない。
compute_metrics Callable 評価指標を自由に設定できる。デフォルトだと結果は全てログに流れていくので,ここでprintしないと標準(エラー)出力には何も出てこない。
model PreTrainedModel BERTなど好きなモデルを指定する。BertModelを派生させて自分で作ったものでもOK。

TrainingArguments

TrainingArgumentsの引数でよく使うのは以下。

パラメタ 引数 説明
output_dir str ログや学習途中のモデルが出力される。必須。
do_train bool 指定する必要がない。
do_eval bool 指定する必要がない。
learning_rate float 5e-5などと指定する。デフォルトは5e-5
num_train_epochs float エポック数を指定する。デフォルトは3.0
per_device_train_batch_size int 訓練時のGPU1枚ごとのバッチサイズを指定する。デフォルトは8。
per_device_eval_batch_size int 評価時のGPU1枚ごとのバッチサイズを指定する。デフォルトは8。
evaluation_strategy str 1エポック毎に評価してほしい場合は'epoch'を指定する。こうすると画面に進捗だけでなく正解率などがエポック毎に表示されるようになる。デフォルトは'no'

GPUの数に応じた最終的なバッチサイズは以下で取得できる。

args.train_batch_size
args.eval_batch_size

挙動

データの読み込まれ方

Trainerクラス内での挙動について説明する。以下のget_train_dataloader()_get_train_sampler()はTrainerクラス内に定義されている。

train()時は,train_datasetが読み込まれるが,この際にget_train_dataloader()によってDataLoaderが読み込まれる。ここで,DataLoadersamplerとして,_get_train_sampler()内でtorch.utils.data.RandomSamplerが指定されている(ただし,TrainingArgumentsgroup_by_lengthをTrueにした場合は除く)。従って,訓練データはランダムにモデルへ投入される。

evaluate()時は,eval_datasetが読み込まれるが,この時にはRandomSamplerは使われない。

compute_metrics

※2022年4月7日にEvalPredictionが大幅に変更された(詳細)。なんと入力が使えるようになっている。

evaluate()から呼び出されるevaluation_loop()内にて,compute_metricsで指定したCallableが呼び出される。この時Callableに渡される引数は,EvalPrediction(predictions=all_preds, label_ids=all_labels)である。

EvalPrediction__getitem__は,次のように定義されている:

インデックス 戻り値 説明
0 Union[np.ndarray, Tuple[np.ndarray]] モデルの出力
1 Union[np.ndarray, Tuple[np.ndarray]] 正解ラベル
2 Optional[Union[np.ndarray, Tuple[np.ndarray]]] 入力。ただし,TrainingArgumentsinclude_inputs_for_metricsTrueにした場合のみ。していない場合はIndexError。
それ以外 IndexError これが出ることによってunpackでの代入ができる。

よって,

def compute_metrics(res):
    logits, labels=res

という書き方ができる。logitsには出力(ただし,損失関数にはかけられていないので,通常のコードであればsoftmax等は適用されていない),labelsにはラベルが入っているので,これらを比較してやれば良い。分類問題であれば,logitsのそれぞれのデータに対してargmaxをして推定ラベルを算出すればよい。いずれもnumpy.ndarrayに変換されていることに注意する(つまり,torch.Tensorの関数は使えず,numpyの関数を使うことになる)。

簡易的に結果を出力したければ,ここでprint()などすると良い。

compute_metricsは戻り値としてOptional[Dict[str, float]]が要求されているので,何か返すのであれば,辞書にして返す。例えば,

return {
    'precision': precision, # precisionはfloat
    'recall': recall # recallもfloat
}

といった感じである。
 ここで指定したkeyは,この後metric_key_prefixによって修飾される(evaluation_loop関数内の# Prefix all keys with metric_key_prefix + '_'コメントの後)。したがって,最終出力は'eval_precision': 0.15のようになる。

36
22
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
36
22

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?