はじめに
近年、単語表現を点で表すのではなく、箱のような幅を持つような表現で埋め込む手法が流行ってきています。そこで、私の研究でもこの分散表現を使ってみたいと思い、 Word2Box: Capturing Set-Theoretic Semantics of Words using Box Embeddings の論文を読んでみましたが、内容を理解するのが困難でした。そこで、Github上のソースコードからなんとかモデルの内容について理解しようと試みてみました。理解が不十分なところが多々ありますので、ご教示いただけると幸いです。
提案手法
Box Embedding とは
まず、Box EmbeddingやWord2Boxについて簡単に説明します。 Box Embeddingはその名の通り、Box状に埋め込みます。
例えば、 "bank" という英単語を2次元で埋め込むことを考えます。Box Embeddingでは多くの場合、図のように各次元(軸)に対してBoxの一番小さな座標($X^-$)と一番大きな座標($X^+$)のペアで埋め込むようにします。
$$
\text{bank} = [X^-, X^+] = [[1, 1], [3, 2]]
$$
こうすることで "bank" のような "岸" や "銀行" といった多義的な意味、つまり、言葉の広がりというものをうまく表現できるようになります。
上記のような単語をBoxで表現するような研究が数々行われる中で Word2Box の特徴としては、word2vecでの学習方法のように、周辺語から中心語を予測させるCBOWによる教師無し学習が行われているという点があります。詳しくはこちらの日本語記事をぜひ参照してみてください。これまでの Box Embedding の手法についてもわかりやすく書かれています。
アルゴリズム
Word2BoxではCBOW、すなわち周辺語から中心語を予測するような学習を行っています。その際、正例と周辺語との距離は近づけ、ランダムにサンプリングしてきた負例と周辺語との距離は遠ざけるような max-margin trainingを行います。ここで、
- 正例となる中心語の分散表現を $x_{\text{pos}} \in \mathbb{R}^{2d}$
- $ns$個の負例となる中心語の分散表現を $X_{\text{neg}} = [x_{\text{neg}0}, x_{\text{neg}1}, ..., x_{\text{neg}(ns-1)}] \in \mathbb{R}^{ns\times 2d}$
- 正例と負例をconcatした中心語の分散表現を $X = concat[x_{\text{pos}}, X_{\text{neg}}] \in \mathbb{R}^{(ns+1)\times 2d}$
- 窓幅$w$の周辺語(context)の分散表現を $X_{\text{con}} = [x_{\text{con}0}, x_{\text{con}1}, ... x_{\text{con}(w-1)}]\in \mathbb{R}^{2w\times2d}$
と表します(表記汚くてすみません)。$d$は次元数です。また、それぞれの$x_i\in \mathbb{R}^{2d}$は$x_i[0:d] = x^-_i \in \mathbb{R}^{d}, x_i[d:2d]=x^+ \in \mathbb{R}^{d}$のように構成されています。
まず、contextの分散表現を集約して、集約されたcontextの分散表現$x_{\text{pooled}}^-$と$x^+_{\text{pooled}}$を作成します。
\begin{align}
x^-_{\text{pooled}} &= \beta \log \left(
\sum_{i=0}^{w-1}\exp
\left( \frac{x^-_{\text{con}i}}{\beta} \right)
\right) \in \mathbb{R}^{d} \tag{1.1}\\
x^+_{\text{pooled}} &= -\beta \log \left(
\sum_{i=0}^{w-1}\exp
\left( \frac{x^+_{\text{con}i}}{-\beta}\right)
\right) \in \mathbb{R}^{d} \tag{1.2}
\end{align}
$\beta$はスカラーのハイパーパラメータです。
ここで、集約されたcontextの分散表現を$ns+1$個に複製して、$X^-_{\text{pooled}} \in \mathbb{R}^{(ns+1)\times d}$、と $X^+_{\text{pooled}} \in \mathbb{R}^{(ns+1)\times d}$とします。次に、この集約されたcontextの分散表現と中心語の分散表現の共通部分の分散表現$Z^-\in \mathbb{R}^{(ns+1)\times d}, Z^+ \in \mathbb{R}^{(ns+1)\times d}$を求めます。
\begin{align}
Z^- &= \beta_{g} \log \left(
\exp \left( \frac{X^-}{\beta_g} \right) +
\exp \left( \frac{X_{\text{pooled}}^-}{\beta_g} \right)
\right) \in \mathbb{R}^{(ns+1)\times d} \tag{2.1}\\
Z^- &= \max\left\{ Z^-, \max\{ X^-, X^-_{\text{pooled}} \} \right\} \tag{2.2}\\
\\
Z^+ &= -\beta_{g} \log \left(
\exp \left( \frac{X^-}{-\beta_g} \right) +
\exp \left( \frac{X_{\text{pooled}}^-}{-\beta_g} \right)
\right) \in \mathbb{R}^{(ns+1)\times d} \tag{2.3} \\
Z^+ &= \min\left\{ Z^+, \min\{ X^+, X^+_{\text{pooled}} \} \right\} \tag{2.4}
\end{align}
$\beta_g$はスカラーのハイパーパラメータです。式(3.1)と式(3.3)でガンベル分布を利用しています。
最後に、先ほど求めた共通部分の分散表現を活性化関数(softplus)に入れ、対数を取ったものを埋め込み次元方向に総和をとって類似度 $\boldsymbol{\hat{y}}$ とします。
\begin{align}
\boldsymbol{\hat{y}} &= [\hat{y}_{\text{pos}}, \hat{y}_{\text{neg}0}, ..., \hat{y}_{\text{neg}(ns-1)}] = \sum_{col}\log f(Z^+ - Z^-) \in \mathbb{R}^{ns+1} \tag{3.1}\\
f&(Z^+ - Z^-) = \frac{1}{\beta} \log \left(
1 + \exp\left(\beta ( Z^+ - Z^- -2\beta_g \gamma_{\text{euler}} \right)
\right) \tag{3.2}
\end{align}
$\gamma_{\text{euler}}$はオイラーの定数です。上記式で定義される類似度は共通部分の大きさを意味しています。
また、損失関数では、max-margin lossを採用しています。こちらでは、正例の類似度を大きくし、負例の類似度小さくするようにします。$\mu$はマージンです。
loss = \sum_{i=0}^{ns-1} \max\left( 0, \hat{y}_{\text{neg}i} - \hat{y}_{\text{pos}} + \mu \right) \tag{4}
実装内容
データセットの作成
訓練用データセット、及び、評価用データセットはdatasets.utils
モジュールのget_iter_on_device
で行います。
Vocabularyの作成
## Create Vocabulary properties
print("Creating iterable dataset ...")
TEXT = torchtext.data.Field()
TEXT.stoi = vocab_stoi
TEXT.freqs = vocab_freq
TEXT.itos = [k for k, v in sorted(vocab_stoi.items(), key=lambda item: item[1])]
自然言語処理のあらゆる前処理をサポートしているtorchtext.data.Field()
でvocabularyを作成しています。
-
stoi
: keyに単語、valueにidを持つ辞書 -
freqs
: 語彙作成にあたって使用したデータセット内における語彙の出現頻度。keyに単語、valueに出現頻度(int
)を持つ辞書 -
itos
:stoi
の逆
# Since we won't train on <pad> and <eos>. These should not come in any sort of
# subsampling and negative sampling part.
TEXT.freqs["<pad>"] = 0
TEXT.freqs["<unk>"] = 0
if eos_mask:
TEXT.freqs["<eos>"] = 0
特殊なトークンの出現頻度は0にして、学習に使用しないようにします。eos
とは End Of Sentence の略です。
訓練用データローダー作成
## Create data on the device
print("Creating iterable dataset on GPU/CPU...")
if data_device == "gpu":
data_device = device
train_iter = LazyDatasetLoader(
training_tensor=train_tokenized, # ← train_tokenized について見てみる
n_splits=1000,
window_size=n_gram,
vocab=TEXT,
subsample_thresh=subsample_thresh,
eos_mask=eos_mask,
device=data_device,
batch_size=batch_size,
)
train_tokenized
をdatasets.word2vecgpu
のLazyDatasetLoader
に入力して訓練用データローダーtrain_iter
を作成しています。
訓練用データ train_tokenized
について
下記のコードはtrain.pt
の中身を見るための自作コードです。
# 訓練用データセット train.pt の中身を見る
train_tokenized = torch.load("./data/ptb/train.pt")
sub_train_tokenized = train_tokenized[0:100] # 100トークン抽出
print(sub_train_tokenized)
# itos (IDから文字列) の辞書を作成
vocab_stoi = json.load(open("./data/ptb/vocab_stoi.json", "r"))
vocab_itos = [k for k, v in sorted(vocab_stoi.items(), key=lambda item: item[1])]
# トークン列を文字列にマッピング
sub_train_texts = [vocab_itos[token] for token in sub_train_tokenized]
print(sub_train_texts)
data/ptb/
にあるtrain.pt
というバイナリファイルに訓練用データセットが保存されています。それをロードすると、サイズ[929589]のTensorを取得することができます。上記のコードを実行した結果が以下のようになります。
python test.py
tensor([ 9971, 9972, 9973, 9975, 9976, 9977, 9981, 9982, 9983, 9984,
9985, 9987, 9988, 9989, 9990, 9992, 9993, 9994, 9995, 9996,
9997, 9998, 9999, 10000, 3, 9257, 0, 4, 73, 394,
34, 2134, 2, 147, 20, 7, 9208, 277, 408, 4,
3, 24, 0, 14, 142, 5, 0, 5466, 2, 3082,
1597, 97, 3, 7683, 0, 4, 73, 394, 9, 338,
142, 5, 2478, 658, 2171, 956, 25, 522, 7, 9208,
277, 5, 40, 304, 439, 3685, 3, 7, 943, 5,
3151, 497, 264, 6, 139, 6093, 4242, 6037, 31, 989,
7, 242, 761, 5, 1016, 2787, 212, 7, 97, 5])
['aer', 'banknote', 'berlitz', 'calloway', 'centrust', 'cluett', 'fromstein', 'gitano', 'guterman', 'hydro-quebec', 'ipo', 'kia', 'memotec', 'mlx', 'nahb', 'punts', 'rake', 'regatta', 'rubens', 'sim', 'snack-food', 'ssangyong', 'swapo', 'wachter', '<eos>', 'pierre', '<unk>', 'N', 'years', 'old', 'will', 'join', 'the', 'board', 'as', 'a', 'nonexecutive', 'director', 'nov.', 'N', '<eos>', 'mr.', '<unk>', 'is', 'chairman', 'of', '<unk>', 'n.v.', 'the', 'dutch', 'publishing', 'group', '<eos>', 'rudolph', '<unk>', 'N', 'years', 'old', 'and', 'former', 'chairman', 'of', 'consolidated', 'gold', 'fields', 'plc', 'was', 'named', 'a', 'nonexecutive', 'director', 'of', 'this', 'british', 'industrial', 'conglomerate', '<eos>', 'a', 'form', 'of', 'asbestos', 'once', 'used', 'to', 'make', 'kent', 'cigarette', 'filters', 'has', 'caused', 'a', 'high', 'percentage', 'of', 'cancer', 'deaths', 'among', 'a', 'group', 'of']
以下のような前処理が施されています。(2023/01/19追記)
- トークナイズ、レンマ化されている
- punctuation(カンマ等)が取り除かれている
-
<eos>
で文が区切られている - 数字は
N
に置換されている - すべての文字が小文字化されている
- アスキーコード以外の文字は除かれている
- 出現頻度が100未満の語は語彙に含まない。
<unk>
で置き換えられているのはこれだと思う。
上記のtrain_tokenized
がデータセットとしてどのように処理されているのか見てみましょう。
torch.utils.data.Dataset
を使った訓練用データセットの作成
train_tokenized
をLazyDatasetLoader
に入力して訓練用データローダーが作成されていますが、このLazyDatasetLoader
はtorch.utils.data
のDataset
クラスを継承しているものになっています。
データを作成する際に重要となる箇所のみ読み解きたいと思います。
def __getitem__(
self, idx: LongTensor
) -> Tuple[LongTensor, LongTensor, BoolTensor, BoolTensor]:
# idx is a Tensor of indicies of the corpus, eg. [2342,12312312,34534,1]
# we will interpret these as the id of the center word
idx += self.pad_size
# Idx is repeated to get the sliding window effect
# For the sliding window part we add the range with idx
window_range = torch.arange(-self.window_size, self.window_size + 1)
idx = idx.unsqueeze(1) + window_range.unsqueeze(0)
# idx = torch.transpose(idx.repeat(2*self.window_size+1,1), 0, 1)
# idx = idx + torch.arange(-self.window_size, self.window_size+1)
# Get the middle slice for the center
# The rest of them are context
center = self.corpus[idx[:, self.window_size]]
context = self.corpus[
torch.cat(
(idx[:, : self.window_size], idx[:, self.window_size + 1 :]), dim=1
)
]
# Get do the subsampling.
center = self.sub_sample_words(center)
context = self.sub_sample_words(context)
# Get rid of the dataset that has the center word as <pad>.
# Or has all context words as <pad>.
if not self.eos_mask:
keep = (center != self.pad_id) & (context != self.pad_id).any(dim=-1)
center = center[keep]
context = context[keep]
assert (center != self.pad_id).all()
context_mask = torch.ones_like(context)
else:
keep = (
(center != self.pad_id)
& (context != self.pad_id).any(dim=-1)
& (center != self.eos_token)
& (context != self.eos_token).any(dim=-1)
)
center = center[keep]
context = context[keep]
assert (center != self.pad_id).all()
context_mask = self.get_mask(context)
# Mask might do away with the whole sentence. In that case remove that
keep = (context_mask != False).any(dim=1).squeeze()
center = center[keep]
context = context[keep]
context_mask = context_mask[keep]
return {
"center_word": center,
"context_words": context,
"context_mask": context_mask,
}
上記のコードについてはよく理解できていませんが、訓練用データセットtrain_iter
として使用されるのはこの返り値となっています。
-
center_word
:size() == [B]
-
context_words
:size() == [B, window_size]
-
context_mask
:size() == [B, window_size]
for i, batch in enumerate(train_iter):
print(TEXT.itos[batch["center_word"][0]])
context_words = [TEXT.itos[word] for word in batch["context_words"][0]]
print(context_words)
print(batch["context_mask"][0].to('cpu').detach().numpy().tolist())
"""
2021年
['日本語', ').', 'NHK', 'みんなのうた', '.', '3月20日', '閲覧', '<pad>', '<eos>', '大槻']
[True, True, True, True, True, True, True, False, False, False]
調査
['<pad>', '生息', '実態', '<pad>', '聞き込み', 'による', '過去', '<pad>', '生息', '情報']
[False, True, True, False, True, True, True, False, True, True]
2020年
['<pad>', '地下鉄', '』', '##ぎょうせい##', '<pad>', '10月14日', '<pad>', '17', '頁', '<pad>']
[False, True, True, True, False, True, False, True, True, False]
筋力トレーニング
['効果', '<pad>', '上げる', 'ため', 'に', '<pad>', '<pad>', '。', '##1990年##', '<pad>']
[True, False, True, True, True, False, False, True, True, False]
tensor(33, device='cuda:0')
"""
(2023/01/26 追記)
データローダーの中身を上記のコードで調べてみました。窓幅を5にしているので、周辺語は10個になっており、中心語は周辺語のちょうど真ん中の位置にあるようになっています。また、<pad>
や<eos>
はマスクでFalse
になっています。
また、<eos>
トークンが周辺語の前半にある場合、<eos>
トークンを含むそれ以前の文がFalse
にマスクされ、<eos>
トークンが周辺語の後半にある場合、<eos>
トークンを含むそれ以降の文がFalse
にマスクされています。
モデルの実装・訓練
train.Trainer.TrainerWordSimilarity
クラスにあるtrain_model
でモデルの実装・訓練が行われています。train_model
の詳細についてみてましょう。
optimizerの設定
## Setting up the optimizers
parameters = filter(lambda p: p.requires_grad, model.parameters())
optimizer = torch.optim.Adam(params=parameters, lr=self.lr)
filter関数によって、勾配を保存している parameter のみを抽出して、optimizer に Adam を使用しています。
negative sampling
# Create negative samples for the batch
batch = self.to(batch, device)
batch = self.add_negatives(batch)
add_negatives
によって negative samplingが行われています。add_negatives
にはnegative_sampling
ファイル内のRandomNegativeCBOW
クラスかRandomNegativeSkipGram
クラスが格納されています。
RandomNegativeCBOW
class RandomNegativeCBOW:
"""
This augments a batch of data to include randomly sampled target center words.
Appends the sampled words with 'center_words' of the batch
"""
def __init__(self, number_of_samples: int = 5, sampling_distn: LongTensor = None):
self.number_of_samples = number_of_samples
self.sampling_distn = sampling_distn
def __call__(self, batch) -> LongTensor:
x, y = batch["context_words"].shape
negatives = torch.multinomial(
self.sampling_distn,
num_samples=self.number_of_samples * x,
replacement=True,
).resize(x, self.number_of_samples)
batch["center_word"] = torch.cat(
(batch["center_word"].unsqueeze(1), negatives), dim=-1
)
return batch
RandomNegativeCBOW
では入力されたbatch
のcenter_word
を更新します。
- 入力前:
center_word.size() == [B]
(B
はバッチサイズ) - 入力後:
center_word.size() == [B, ns+1]
(ns
はネガティブサンプリング数)
偽の中心語がネガティブサンプリングされた形式になります。
RandomNegativeSkipGram
RandomNegativeSkipGram
クラスでは偽のcontextsをネガティブサンプリングする形式となっています。デフォルトでは前述のCBOW
の方式で訓練を行なっています。
スコアの計算(順伝搬)
# Start the optimization
optimizer.zero_grad()
score = model.forward(
batch["center_word"],
batch["context_words"],
batch["context_mask"],
train=True,
)
順伝搬をどのように計算しているかを詳しく見てみましょう。使用しているモデルはWord2Box
を継承しているWord2BoxConjunction
となっているので、まずWord2Box
の実装についてみていこうと思います。
Word2Box
(埋め込み表現の初期化)
class Word2Box(BaseModule):
def __init__(
self,
TEXT=None,
embedding_dim=50,
batch_size=10,
n_gram=4,
volume_temp=1.0,
intersection_temp=1.0,
box_type="BoxTensor",
**kwargs
):
super(Word2Box, self).__init__()
# Model
self.batch_size = batch_size
self.n_gram = n_gram
self.vocab_size = len(TEXT.itos)
self.embedding_dim = embedding_dim
# Box features
self.volume_temp = volume_temp
self.intersection_temp = intersection_temp
self.box_type = box_type
# Create embeddings
self.embeddings_word = BoxEmbedding(
self.vocab_size, self.embedding_dim, box_type=box_type
)
self.embedding_context = BoxEmbedding(
self.vocab_size, self.embedding_dim, box_type=box_type
)
Word2BoxConjunction
に関係がある__init__
関数のみ焦点を当てます。
ここでは、BoxEmbedding
クラスにより、vocabularyの埋め込み表現の初期化を行なっています。torch.nn.init.uniform_
により一様分布で初期化されています。初期化時のテンソルのサイズは[box_embedding_dim*2]
となっており、テンソルの後ろ半分は前半分の値を+0.1した値となっています。
Word2BoxConjunction
(順伝搬)
class Word2BoxConjunction(Word2Box):
def intersect_multiple_box(self, boxes, mask):
beta = self.intersection_temp
z = boxes.z.clone()
Z = boxes.Z.clone()
# maskが-infなので、対数を取ると限りなく0に近づく
z[~mask] = float("-inf")
Z[~mask] = float("inf")
z = beta * torch.logsumexp(z / beta, dim=1, keepdim=True)
Z = -beta * torch.logsumexp(-Z / beta, dim=1, keepdim=True)
return BoxTensor.from_zZ(z, Z) # [2, box_embedding_dim]
def forward(self, idx_word, idx_context, mask_context, train=True):
context_boxes = self.embedding_context(idx_context) # Batch_size * 2 * dim
# Notce that the context is not masked yet. Need to mask them as well.
word_boxes = self.embeddings_word(idx_word) # Batch_size * ns+1 * 2 * dim
pooled_context = self.intersect_multiple_box(context_boxes, mask_context) # Batch_size *
if self.intersection_temp == 0.0:
score = word_boxes.intersection_log_soft_volume(
pooled_context, temp=self.volume_temp
)
else:
score = word_boxes.gumbel_intersection_log_volume(
pooled_context,
volume_temp=self.volume_temp,
intersection_temp=self.intersection_temp,
)
return score
先ほどのWord2Box
を継承しているWord2BoxConjunction
について詳しくみてます。このクラスのforward
が実際にスコアの計算(順伝搬)をしているところとなります。
intersect_multiple_box
では複数のBoxの共通部分のBoxを計算するところとなっています。
$$
\operatorname{logsumexp}(x) = \log \sum_i \exp(x_i)
$$
上記の式のように、LogSumExpと呼ばれる方法で共通部分を計算しています。
そしてforward
では、context
の共通部分であるpooled_context
と中心語のboxであるword_boxes
との類似度(スコア)を計算しています。それぞれのテンソルのサイズは以下の通りになっています。
print(f"idx_context.size()=={idx_context.size()}")
z = context_boxes.z.clone()
print(f"context_boxes.z.size()=={z.size()}")
print(f"idx_word.size()=={idx_word.size()}")
z = word_boxes.z.clone()
print(f"word_boxes.z.size()=={z.size()}")
z = pooled_context.z.clone()
print(f"pooled_context.z.size()=={z.size()}")
"""
idx_context.size()==torch.Size([2767, 10]) # [B, 2*window_size]
context_boxes.size()==torch.Size([2767, 10, 64]) # [B, 2*window_size, embedding_dim]
idx_word.size()==torch.Size([2767, 3]) # [B, ns+1]
word_boxes.size()==torch.Size([2767, 3, 64]) # [B, ns+1, embedding_dim]
pooled_context.z.size()==torch.Size([2767, 1, 64]) # [B, 1, embedding_dim]
"""
次項で2通りのスコアの算出方法について詳細にみてみようと思います。
intersection_log_soft_volume
def intersection_log_soft_volume(
self,
other: TBoxTensor,
temp: float = 1.0,
gumbel_beta: float = 1.0,
bayesian: bool = False,
scale: Union[float, Tensor] = 1.0,
) -> Tensor:
z, Z = self._intersection(other, gumbel_beta, bayesian)
vol = self._log_soft_volume(z, Z, temp=temp, scale=scale)
return vol
self._intersection
で中心語とコンテキストの共通部分を算出しています。
def _intersection(
self: TBoxTensor,
other: TBoxTensor,
gumbel_beta: float = 1.0,
bayesian: bool = False,
) -> Tuple[Tensor, Tensor]:
t1 = self
t2 = other
# ~~~略~~~ #
# intersection_log_soft_volume で使う
else:
z = torch.max(t1.z, t2.z)
Z = torch.min(t1.Z, t2.Z)
return z, Z
intersection_log_soft_volume
中の_intersection
では、単純にz
($X^-$)の最大値とZ
($X^+$)の最小値をとって、中心語とコンテキストの共通部分を計算しています。
@classmethod
def _log_soft_volume(
cls, z: Tensor, Z: Tensor, temp: float = 1.0, scale: Union[float, Tensor] = 1.0
) -> Tensor:
eps = torch.finfo(z.dtype).tiny # type: ignore
if isinstance(scale, float):
s = torch.tensor(scale)
else:
s = scale
return torch.sum(
torch.log(F.softplus(Z - z, beta=temp) + 1e-23), dim=-1
) + torch.log(
s
) # need this eps to that the derivative of log does not blow
そして、_log_soft_volume
では、共通部分のz
とZ
をもとにsoftplusという活性化関数を用いてスコアを算出しています。
score = \log\left(\frac{1}{\beta}・\log(1+\exp(\beta・(X^+ - X^-))\right)
softplus関数はReLUと同じような形をしていますが、
- 関数への入力値が0付近では、出力値が0にはならない
- 入力値が小さくなればなるほど出力値が0に近づいていき、入力値が大きくなればなるほど出力値が入力値と同じ値に近づいていく
といった特徴があります。
gumbel_intersection_log_volume
def gumbel_intersection_log_volume(
self: TBoxTensor,
other: TBoxTensor,
volume_temp=1.0,
intersection_temp: float = 1.0,
scale=1.0,
) -> TBoxTensor:
z, Z = self._intersection(other, gumbel_beta=intersection_temp, bayesian=True)
vol = self._log_soft_volume_adjusted(
z, Z, temp=volume_temp, gumbel_beta=intersection_temp, scale=scale
)
return vol
こちらでも前述と同様に_intersection
で中心語とコンテキストの共通部分を算出しています。
def _intersection(
self: TBoxTensor,
other: TBoxTensor,
gumbel_beta: float = 1.0,
bayesian: bool = False,
) -> Tuple[Tensor, Tensor]:
t1 = self
t2 = other
if bayesian:
try:
z = gumbel_beta * torch.logaddexp(
t1.z / gumbel_beta, t2.z / gumbel_beta
)
z = torch.max(z, torch.max(t1.z, t2.z))
Z = -gumbel_beta * torch.logaddexp(
-t1.Z / gumbel_beta, -t2.Z / gumbel_beta
)
Z = torch.min(Z, torch.min(t1.Z, t2.Z))
except Exception as e:
print("Gumbel intersection is not possible")
breakpoint()
# ~~~略~~~ #
return z, Z
gumbel_intersection_log_volume
中の_intersection
ではbayesian
モードで共通部分を算出しています。数式で表すと以下のようになります。
\begin{align}
X^- &= \beta_g \log \left(
\exp \left( \frac{X^-_{t1}}{\beta_g} \right)
+ \exp \left( \frac{X^-_{t2}}{\beta_g} \right)
\right) \\
X^- &= \max\left\{ X^-, \max\{X^-_{t1}, X^-_{t2}\} \right\} \\ \\
X^+ &= -\beta_g \log \left(
\exp \left( \frac{-X^+_{t1}}{\beta_g} \right)
+ \exp \left( \frac{-X^+_{t2}}{\beta_g} \right)
\right) \\
X^+ &= \min\left\{ X^+, \min\{X^+_{t1}, X^+_{t2}\} \right\}
\end{align}
Gumbel分布を利用してBoxの頂点を算出しています。Gumbel分布を利用することで、局所識別性(local indentifiability)を向上させることを望んでいます。
@classmethod
def _log_soft_volume_adjusted(
cls,
z: Tensor,
Z: Tensor,
temp: float = 1.0,
gumbel_beta: float = 1.0,
scale: Union[float, Tensor] = 1.0,
) -> Tensor:
eps = torch.finfo(z.dtype).tiny # type: ignore
if isinstance(scale, float):
s = torch.tensor(scale)
else:
s = scale
return (
torch.sum(
torch.log(
F.softplus(Z - z - 2 * euler_gamma * gumbel_beta, beta=temp) + 1e-23
),
dim=-1,
)
+ torch.log(s)
)
そして、_log_soft_volume_adjusted
では、前述の_log_soft_volume
と同様に、共通部分のz
とZ
をもとにsoftplus関数を用いてスコアを算出しています。
score = \log\left(\frac{1}{\beta}・\log(1+\exp(\beta・(X^+ - X^- -2\beta_{g} \gamma_{\operatorname{euler}}))\right)
Loss function (損失関数)
assert (
score.shape[-1] == self.negative_samples + 1 # .size()==[B, ns+1]
) # check the shape of the score
# Score log_intersection_volume (un-normalised) for Word2Box
pos_score = score[..., 0].reshape( # .size()==[B, 1]
-1, 1
) # The first element correspond to the Positive
neg_score = score[..., 1:].reshape( # .size()==[B, ns]
-1, self.negative_samples
) # The rest of the elements are for negative samples
損失を計算する前に positive sample のスコアと negative samples のスコアを分けます。
# Calculate Loss
loss = self.loss_fn(
pos_score, neg_score, margin=self.margin
) # Margin is not required for nll or nce
# Handled through kwargs in loss.py
total_loss = torch.sum(loss)
avg_loss = torch.mean(loss)
if torch.isnan(loss).any():
raise RuntimeError("Loss value is nan :(")
loss_fn
として定義される損失関数はloss.py
中にある、nll
(Negative Log Likelihood)、nce
(Noise Contrastive Estimation)、max_margin
の3種類から選択できますが、ソールコード上でも論文でもmax_margin
が採用されていました。
max_margin
def max_margin(pos, neg, margin=5.0):
"""
This is max margin loss for box embeddings.
Here, the input scores can be un-normalised. The object here
is to make increse the pos similarity score more than a margin
from the negative scores. If that margin is satisfied then the
loss is zero.
Args:
pos: Unnormalised similarity(maybe log in case of Boxes) score for positives.
neg: Unnormalised similarity(maybe log in case of Boxes) score for negatives.
Output:
loss = - max(0, pos - mean(neg) + margin)
"""
# Replicate the positive score number of negative sample times
zero = torch.tensor(0.0).to(device)
return torch.sum(torch.max(zero, neg - pos + margin), dim=1)
loss = \sum \max\left( 0, score_{\text{neg}i} - score_{\text{pos}} + \mu \right)
モデルの評価
モデルの評価はtrain.Trainer.TrainerWordSimilarity
クラスのmodel_eval
で行います。
def word_similarity(self, w1, w2):
with torch.no_grad():
word1 = self.embeddings_word(w1)
word2 = self.embeddings_word(w2)
if self.intersection_temp == 0.0:
score = word1.intersection_log_soft_volume(word2, temp=self.volume_temp)
else:
score = word1.gumbel_intersection_log_volume(
word2,
volume_temp=self.volume_temp,
intersection_temp=self.intersection_temp,
)
return score
評価に使用する予測値は、上記のword_similarity
関数の出力値を利用します。2つの単語に対してintersection_log_soft_volume
関数もしくは、gumbel_intersection_log_volume
関数を用いてスコアの計算を行っています。こちらの計算方法については前述しているのでそちらを見返してみてください。
correlation = spearmanr(predicted_scores, real_scores)[0]
そして、先ほど得た予測値と評価用データセットのラベル値に対して、スピアマンの順位相関係数で評価を行っています。スピアマン順位相関係数は、2変数が順位データや5段階評価データなどの順序尺度の場合に有効な評価手法です。
学習済みモデルの利用方法
モデルを読み込む
モデルを呼び出すための実行ファイルはsrc
と同階層に配置しました。
@click
のパラメータの呼び出し方法が理解できていないため、とりあえず不細工ですがこのようにしました。絶対にもっといい方法があると思います。。。
もっといいモデルのloadの仕方わかる方いたら是非ご教示いただきたいです!
# モデルのインスタンス作成に必要なモジュールをインポート
from src.language_modeling_with_boxes.models import Word2Box, Word2Vec, Word2VecPooled, Word2BoxConjunction, Word2Gauss
from src.language_modeling_with_boxes.datasets.utils import get_iter_on_device
from src.language_modeling_with_boxes.__main__ import main
import torch
from torch import LongTensor, BoolTensor, Tensor, IntTensor
import pickle, json
# 保存してあるモデルと同じパラメータを設定 (すごい無理矢理ですが)
config = {
"batch_size": 4096,
"box_type": "BoxTensor",
"data_device": "gpu",
"dataset": "ptb",
"embedding_dim": 64,
"eos_mask": True,
"eval_file": "../data/similarity_datasets/",
"int_temp": 1.9678289474987882,
"log_frequency": 10,
"loss_fn": "max_margin",
"lr": 0.004204091643267762,
"margin": 5,
"model_type": "Word2BoxConjunction",
"n_gram": 5,
"negative_samples": 2,
"num_epochs": 10,
"subsample_thresh": 0.001,
"vol_temp": 0.33243242379830407,
"save_model": "",
"add_pad": "",
"save_dir": "results",
}
# 語彙やデータローダーを作成
TEXT, train_iter, val_iter, test_iter, subsampling_prob = get_iter_on_device(
config["batch_size"],
config["dataset"],
config["model_type"],
config["n_gram"],
config["subsample_thresh"],
config["data_device"],
config["add_pad"],
config["eos_mask"],
)
# モデルのインスタンスを作成
model = Word2BoxConjunction(
TEXT=TEXT,
embedding_dim=config["embedding_dim"],
batch_size=config["batch_size"],
n_gram=config["n_gram"],
intersection_temp=config["int_temp"],
volume_temp=config["vol_temp"],
box_type=config["box_type"],
)
# 作成したインスタンスに保存してあるパラメータを読み込む
model.load_state_dict(torch.load('results/best_model.ckpt'))
print(model)
類似度を計算する
word_1 = "dog"
word_2 = "cat"
# 文字列 を ID に変換
word_1_id = TEXT.stoi["dog"]
word_2_id = TEXT.stoi["cat"]
# INT の ID を LongTensorにキャストした後で類似度を計算
similarity = model.word_similarity(LongTensor([word_1_id]), LongTensor([word_2_id]))
print(similarity)
参考文献
- Dasgupta, Shib, et al. "Word2Box: Capturing Set-Theoretic Semantics of Words using Box Embeddings." Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers). 2022.
- Word2Boxの著者らによるソースコード(Github)
- 萩原 正人. "単語を箱で表現!新たな埋め込み手法 Box Embedding を基礎から理解." ステート・オブ・AIガイド. 2022/10/13.
- Pythonのfilter()関数で要素を抽出!map、reduceとの違いも解説
- ◆スピアマン順位相関係数の無相関検定◆
- スピアマンの順位相関係数
- TORCHTEXT(公式)
- torchtextの仕様変更対応 (1) Field
- TORCH.UTILS.DATA(公式)
- TORCH.NN.INIT(公式)
- TORCH.LOGSUMEXP(公式)
- logsumexpとは|Numpy・PyTorchによる実装例も解説!
- SOFTPLUS(公式)
- [活性化関数]ソフトプラス関数(Softplus関数)とは?
- TORCH.LOGADDEXP (公式)
- Dasgupta, Shib, et al. "Improving local identifiability in probabilistic box embeddings." Advances in Neural Information Processing Systems 33 (2020): 182-192.