4
2

More than 1 year has passed since last update.

transformersのTrainerでもにょった

Posted at

はじめに

payanottyです。
transformersのTrainerを使ってて軽くもにょったことがあったので書いてみたいと思います。

もにょったこと

transformersはHugging Face社が提供している深層学習ライブラリです。
自然言語処理に特化していて、pythonで自然言語処理をするなら定番のライブラリになっています。

model hubには無数に事前学習済みモデルが登録されており、これらの中から好きなモデルをロードしてきてファインチューニングすることができます。

ファインチューニングの際に自分で学習処理を書くのもよいのですが、transformersの中で提供されているTrainerクラスを使うととても簡単です。

from transformers import TrainingArguments
from transformers import Trainer

training_args = TrainingArguments()

trainer = Trainer(
    model=model, args=training_args, train_dataset=small_train_dataset, eval_dataset=small_eval_dataset
)

Trainerクラスで便利な機能として、Callbackというものがあります。
これは学習中に行いたい処理をCallbackクラスとして実装しておき、Trainerの引数にそのCallbackのインスタンスを指定することができるというものです。

class MyCallback(TrainerCallback):
    "A callback that prints a message at the beginning of training"

    def on_train_begin(self, args, state, control, **kwargs):
        print("Starting training")

trainer = Trainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    callbacks=[MyCallback]  # We can either pass the callback class this way or an instance of it (MyCallback())
)

筆者はTrainerクラスを使って事前学習を回そうとしていて、学習の経過をMLflow serverに記録したいと考えていました。
少し調べてみるとMLflowCallbackはすでにあるということで、それを継承して少し処理を書き換えたCallbackを実装すれば良さそうだと思いました。

class MyMLflowCallback(MLflowCallback):
    """
    なんかしら処理
    """

こんな感じでCallbackを入れて実験を回してみました。
MLflow UIで実験経過を確認してみると、なるほど、Callbackに実装したとおり、きちんとメトリクスが記録され、アーティファクト保存も問題なさそうです。
でもよく見ると、experimentが2つできてました。
そのうちの1つは自分で設定したものなので良いのですが、もう1つの方は何故か勝手にできていて、自分が設定した覚えのないメトリクスが記録され続けていました。

なんで勝手にexperimentできてんのきもきもきもきも、もにょ~~~~?ってなりました。

なんだったのか

実はTrainerにはデフォルトでMLflowCallbackが指定されており、MLflowがインストールされている場合は勝手にこのMLflowCallbackの処理が行われる、という仕様になっていました。
なので、デフォルトのCallbackが作ったexperimentと、自作のCallbackが作ったexperimentとで、それぞれが独立して走っていたというわけでした。
今回のように自作のMLflowCallbackを使いたい場合は、以下のようにしてデフォルトのCallbackをTrainerから削除することを推奨します。

# デフォルト消して
trainer.remove_callback(MLflowCallback)
# 自作のを入れる
trainer.add_callback(MyMLflowCallback)

MLflow serverにアーティファクトを送りたかったり、モデルレジストリにモデル登録したかったり、デフォルトのMLflowCallbackではちょっと足りないことをしたいことはあると思うので、このことは覚えておくともにょらなくて済むと思います。

もにょるってなんだよ

わた天を見ましょう。

image.png
©椋木ななつ・一迅社/わたてん製作委員会

4
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
2