Sentence-TransformersでEmbedding Modelをトレーニングする際に、どのようにCallback関数を書いたらいいかの日本語の記事がなかったので、備忘録を兼ねて残しておきます。
Sentence-TransformersでEmbedding Modelをトレーニングするコード自体はこちらを参考にしてください。LlamaIndexによる埋め込みモデルのファインチューニングを試す
また、こちらは公式のDocumentです。SentenceTransformer
今回はEvaluatorから提供されるスコアが最高値になった際にモデルを保存するCallback関数を例に書きます。
best_model自体は既にfitのオプションとして実装されているのでご注意。
Callback関数
best_score = -999999
def custom_callback(score, epoch, steps):
global best_score
print(f"score : {score}, epoch : {epoch + 1}, steps :{steps} ")
if best_score <= score:
model.save("/output/")
model.fit()
model.fit(
train_objectives=[(loader, loss)],
epochs=EPOCHS,
warmup_steps=warmup_steps,
output_path="results",
show_progress_bar=True,
evaluator=InformationRetrievalEvaluator,
evaluation_steps=1,
callback=callback,
)
このコードでは、InformationRetrievalEvaluatorをEvaluatorとして採用しています。
Evaluatorを使った際は、validation datasetに対するEvaluatorのスコア(lossではないので注意), 現在のEpoch数, 現在のstep数が引数として、callback関数に渡されます。
現状のsentence-transformersだとlossの数値をcallback関数に直接渡せないので、callbackをクラスから定義する必要があるのですが、それはまた別の機会とします。