LoginSignup
0
0

sentence-transformesでのcallback関数の書き方

Posted at

Sentence-TransformersでEmbedding Modelをトレーニングする際に、どのようにCallback関数を書いたらいいかの日本語の記事がなかったので、備忘録を兼ねて残しておきます。

Sentence-TransformersでEmbedding Modelをトレーニングするコード自体はこちらを参考にしてください。LlamaIndexによる埋め込みモデルのファインチューニングを試す
また、こちらは公式のDocumentです。SentenceTransformer

今回はEvaluatorから提供されるスコアが最高値になった際にモデルを保存するCallback関数を例に書きます。
best_model自体は既にfitのオプションとして実装されているのでご注意。

Callback関数

best_score = -999999
def custom_callback(score, epoch, steps):
    global best_score
    print(f"score : {score}, epoch : {epoch + 1}, steps :{steps}  ")
    if best_score <= score:
        model.save("/output/")

model.fit()

model.fit(
    train_objectives=[(loader, loss)],
    epochs=EPOCHS,
    warmup_steps=warmup_steps,
    output_path="results",
    show_progress_bar=True,
    evaluator=InformationRetrievalEvaluator, 
    evaluation_steps=1,
    callback=callback,
)

このコードでは、InformationRetrievalEvaluatorをEvaluatorとして採用しています。
Evaluatorを使った際は、validation datasetに対するEvaluatorのスコア(lossではないので注意), 現在のEpoch数, 現在のstep数が引数として、callback関数に渡されます。

現状のsentence-transformersだとlossの数値をcallback関数に直接渡せないので、callbackをクラスから定義する必要があるのですが、それはまた別の機会とします。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0