目的
- sentence-transformersとtripletlossを用いて学習を行うコードのversion upに伴う微修正
参考
code
- 記載時のversionは
sentence-transformers 2.7.0
- tripletなdatasetを使用しているのは参考の第9回の方なので,そちらのコードを調整する
model load
import transformers
transformers.BertTokenizer = transformers.BertJapaneseTokenizer
from sentence_transformers import SentenceTransformer
from sentence_transformers import models
from sentence_transformers.losses import TripletDistanceMetric, TripletLoss
from sentence_transformers.evaluation import TripletEvaluator
from sentence_transformers.readers import TripletReader
from sentence_transformers.datasets import SentencesDataset
from torch.utils.data import DataLoader
transformer = models.Transformers('cl-tohoku/bert-base-japanese-whole-word-masking')
pooling = models.Pooling(
transformer.get_word_embedding_dimension(),
pooling_mode_mean_tokens=True,
pooling_mode_cls_token=False,
pooling_mode_max_tokens=False
)
model = SentenceTransformer(modules=[transformer, pooling])
models.BERT => models.Transformers
に修正
TripletEvaluatorの調整
dev_data = SentencesDataset(triplet_reader.get_examples('triplet_dev.tsv'), model=model)
dev_dataloader = DataLoader(dev_data, shuffle=False, batch_size=BATCH_SIZE)
evaluator = TripletEvaluator(dev_dataloader)
TripletEvaluato
rの引数にdataloaderを与えているが,ver 0.3.3あたりで(anchor[List], positive[List], negative[List])
で与えるように修正されている.
そのため,anchor,positive,negativeのテキストデータをlistに入れなおす必要がある.
from sentence_transformers.readers import TripletReader
anchor_list = []
positive_list = []
negative_list = []
# crate [each anchor, positive, negative] text List
for example in triplet_reader.get_examples('triplet_dev.tsv')
anchor_list.append(example.texts[0])
positive_list.append(example.texts[1])
negative_list.append(example.texts[2])
dev_data = SentencesDataset(triplet_reader.get_examples('triplet_dev.tsv'), model=model)
dev_dataloader = DataLoader(dev_data, shuffle=False, batch_size=BATCH_SIZE)
evaluator = TripletEvaluator(anchor_list, positive_list, negative_list)
これで modelを学習できるようになる!
あんまり綺麗じゃない & いい方法他にもありそうなのですが,とりあえず動かすというだけならまぁこれでいいでしょう
dataLoaderを与える場合の方が綺麗な実装な気がするんですが,なんでListで与える形式になったんだろう?? 謎