なんの記事?
名寄せのデータセットを用いて、名寄せのタスクに取り組んだので、備忘録として残します。
名寄せとは?
https://buildersbox.corp-sansan.com/entry/2020/03/10/110000
上記記事にもあるとおり、2つの文字列が同じ実体(エンティティ)を指すかどうか?を予測するタスクです。
エンティティ・リンキングとは異なり実体のエンティティへの結びつけは今回のタスクのスコープからは外しましたが、実際には特定の実体への紐付け(=エンティティ・リンキング)も含める場合もあります。
モデルおよび使用した特徴量
上記が組んだモデルになります。
今回はInferSentや他のエンティティ・リンキングで使用される特徴量を考慮し、以下の4つを特徴量として使用しました。
それぞれ見ていきましょう。
メンション同士の結合による埋め込み
モデル図にも示したように、BERTに各メンションを結合したトークンを与え、最終層で得られた[CLS]トークンでの埋め込みを特徴として使用しました。
トークナイズ部(AllenNLPを使用)
l_plus_r = [Token('[CLS]')]
l_plus_r += [Token(split_token) for split_token in self.custom_tokenizer_class.tokenize(txt=data['l'])][:self.config.max_mention_length]
l_plus_r += [Token(BOND_TOKEN)]
l_plus_r += [Token(split_token) for split_token in self.custom_tokenizer_class.tokenize(txt=data['r'])][:self.config.max_mention_length]
l_plus_r += [Token('[SEP]')]
context_field = TextField(l_tokenized, self.token_indexers)
[CLS]トークンを用いることで、同じエンティティかどうかを予測する2メンション同士の相互attentionまで含めた埋め込みを特徴として用いました。
2つのメンションの埋め込みの絶対値差
InferSentで利用されていたので、こちらでも使用してみました。同じメンションであれば近く、別のメンションであれば遠くのベクトルへとエンコーディングされることを期待しています。
Levenstein Distance
周辺文脈を使用できないので、今回は編集距離もモデルの特徴量として用いました。
2つのメンションのサブワード共通数
各メンションをサブワードに分割した際の共通サブワード数を特徴として入れました。
(Pdb) data['r']
'スペインワールドカップ'
(Pdb) p self.custom_tokenizer_class.tokenize(txt=data['r'])
['ス', '##ヘイ', '##ン', '##ワール', '##ト', '##カ', '##ッフ']
(Pdb) data['l']
'FIFAワールドカップ'
(Pdb) p self.custom_tokenizer_class.tokenize(txt=data['l'])
['f', '##if', '##a', '##ワール', '##ト', '##カ', '##ッフ']
上記だと4になります。
モデル実装
AllenNLPを用いて実装しました。
モデル部分
import torch
import torch.nn as nn
from allennlp.modules.seq2vec_encoders import Seq2VecEncoder, PytorchSeq2VecWrapper
from allennlp.models import Model
from overrides import overrides
from allennlp.training.metrics import CategoricalAccuracy, BooleanAccuracy
import copy
class ResolutionLabelClassifier(Model):
def __init__(self, args,
mention_encoder: Seq2VecEncoder,
vocab):
super().__init__(vocab)
self.args = args
self.mention_encoder = mention_encoder
self.accuracy = BooleanAccuracy()
self.BCEWloss = nn.BCEWithLogitsLoss()
self.mesloss = nn.MSELoss()
self.istrainflag = 1
self.cos = nn.CosineSimilarity(dim=1, eps=1e-6)
self.linear_for_cat = nn.Linear(self.mention_encoder.get_output_dim()* 2 + 2, 1)
self.linear_for_bond = nn.Linear(self.mention_encoder.get_output_dim(), 1)
def forward(self, l, r, label, mention_uniq_id, l_plus_r, lev, subword_match_num):
l_mention = self.mention_encoder(l)
r_mention = self.mention_encoder(r)
l_and_r_cross = self.mention_encoder(l_plus_r)
if self.args.scoring_function_for_model == 'sep':
scores = self.linear_for_cat(torch.cat((l_mention, r_mention), dim=1))
else:
cated = torch.cat((l_and_r_cross, torch.abs(l_mention - r_mention)), dim=1)
cated = torch.cat((cated, subword_match_num.view(-1, 1).float()), dim=1)
scores = self.linear_for_cat(torch.cat((cated, lev.view(-1, 1).float()), dim=1))
loss = self.BCEWloss(scores.view(-1), label.float())
output = {'loss': loss}
if self.istrainflag:
binary_class = (torch.sigmoid(scores.view(-1)) > 0.5).int()
self.accuracy(binary_class, label)
return output
@overrides
def get_metrics(self, reset: bool = False):
return {"accuracy": self.accuracy.get_metric(reset)}
エンコーダ部分
import torch.nn as nn
from allennlp.modules.seq2vec_encoders import Seq2VecEncoder, PytorchSeq2VecWrapper, BagOfEmbeddingsEncoder
from allennlp.modules.seq2vec_encoders import BertPooler
from overrides import overrides
from allennlp.nn.util import get_text_field_mask
class Pooler_for_mention(Seq2VecEncoder):
def __init__(self, args, word_embedder):
super(Pooler_for_mention, self).__init__()
self.args = args
self.huggingface_nameloader()
self.bertpooler_sec2vec = BertPooler(pretrained_model=self.bert_weight_filepath)
self.word_embedder = word_embedder
self.word_embedding_dropout = nn.Dropout(self.args.word_embedding_dropout)
def huggingface_nameloader(self):
if self.args.bert_name == 'japanese-bert':
self.bert_weight_filepath = 'cl-tohoku/bert-base-japanese'
else:
self.bert_weight_filepath = 'dummy'
print('Currently not supported', self.args.bert_name)
exit()
def forward(self, contextualized_mention):
mask_sent = get_text_field_mask(contextualized_mention)
mention_emb = self.word_embedder(contextualized_mention)
mention_emb = self.word_embedding_dropout(mention_emb)
mention_emb = self.bertpooler_sec2vec(mention_emb, mask_sent)
return mention_emb
@overrides
def get_output_dim(self):
return 768
今回は非常に単純に、エンコードした特徴を結合し、隠れ層に通すだけにしました。
実験結果
dev acc. ~63%, test acc.~ 64% となりました。特にパラメタチューニング等は行っていません。
今後、時間が出来ればエラー分析も行いたいと思います。
(2021-03-06 追記)上記はコードの不具合により、学習率0.001で回していたことが判明しました。追って学習率~1e-5 の実験が完了次第報告できればと思います。
追記を参照して下さい。
ソースコード
にあげました。
追記(2021-03-28)
実装を見直した所、dev, test acc. ともに85%程度まで上昇しました。
追ってエラー分析も出来ればと思います。
447070it [06:15, 1189.72it/s]
146151it [01:59, 1219.94it/s]
149407it [02:03, 1206.42it/s]
building vocab: 100%|##########| 593221/593221 [00:04<00:00, 129461.72it/s]
You provided a validation dataset but patience was set to None, meaning that early stopping is disabled
accuracy: 0.7831, batch_loss: 0.3731, loss: 0.4547 ||: 100%|##########| 13971/13971 [1:44:42<00:00, 2.22it/s]
accuracy: 0.8297, batch_loss: 0.0949, loss: 0.3775 ||: 100%|##########| 4568/4568 [09:57<00:00, 7.65it/s]
accuracy: 0.8443, batch_loss: 0.3598, loss: 0.3525 ||: 100%|##########| 13971/13971 [1:45:00<00:00, 2.22it/s]
accuracy: 0.8482, batch_loss: 0.0662, loss: 0.3455 ||: 100%|##########| 4568/4568 [09:56<00:00, 7.66it/s]
accuracy: 0.8625, batch_loss: 0.3392, loss: 0.3165 ||: 100%|##########| 13971/13971 [1:44:54<00:00, 2.22it/s]
accuracy: 0.8551, batch_loss: 0.0676, loss: 0.3328 ||: 100%|##########| 4568/4568 [09:56<00:00, 7.66it/s]
accuracy: 0.85, loss: 0.35 ||: : 4669it [10:08, 7.68it/s]
===PARAMETERS===
debug False
dataset bc5cdr
dataset_dir ./dataset/
bert_name japanese-bert
scoring_function_for_model concat
cached_instance False
lr 1e-06
weight_decay 0
beta1 0.9
beta2 0.999
epsilon 1e-08
amsgrad False
word_embedding_dropout 0.1
cuda_devices 0
num_epochs 3
batch_size_for_train 32
batch_size_for_eval 32
debug_sample_num 2000
max_mention_length 30
===PARAMETERS END===
(中略)
Building the vocabulary
{'accuracy': 0.8501408903197306, 'loss': 0.34764245371171687}