18
9

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

Word2Boxの実装を読み解く

Last updated at Posted at 2022-12-20

はじめに

近年、単語表現を点で表すのではなく、箱のような幅を持つような表現で埋め込む手法が流行ってきています。そこで、私の研究でもこの分散表現を使ってみたいと思い、 Word2Box: Capturing Set-Theoretic Semantics of Words using Box Embeddings の論文を読んでみましたが、内容を理解するのが困難でした。そこで、Github上のソースコードからなんとかモデルの内容について理解しようと試みてみました。理解が不十分なところが多々ありますので、ご教示いただけると幸いです。

提案手法

Box Embedding とは

まず、Box EmbeddingやWord2Boxについて簡単に説明します。 Box Embeddingはその名の通り、Box状に埋め込みます。
スクリーンショット 2022-12-16 20.47.34.png
例えば、 "bank" という英単語を2次元で埋め込むことを考えます。Box Embeddingでは多くの場合、図のように各次元(軸)に対してBoxの一番小さな座標($X^-$)と一番大きな座標($X^+$)のペアで埋め込むようにします。
$$
\text{bank} = [X^-, X^+] = [[1, 1], [3, 2]]
$$
こうすることで "bank" のような "岸" や "銀行" といった多義的な意味、つまり、言葉の広がりというものをうまく表現できるようになります。

上記のような単語をBoxで表現するような研究が数々行われる中で Word2Box の特徴としては、word2vecでの学習方法のように、周辺語から中心語を予測させるCBOWによる教師無し学習が行われているという点があります。詳しくはこちらの日本語記事をぜひ参照してみてください。これまでの Box Embedding の手法についてもわかりやすく書かれています。

アルゴリズム

Word2BoxではCBOW、すなわち周辺語から中心語を予測するような学習を行っています。その際、正例と周辺語との距離は近づけ、ランダムにサンプリングしてきた負例と周辺語との距離は遠ざけるような max-margin trainingを行います。ここで、

  • 正例となる中心語の分散表現を $x_{\text{pos}} \in \mathbb{R}^{2d}$
  • $ns$個の負例となる中心語の分散表現を $X_{\text{neg}} = [x_{\text{neg}0}, x_{\text{neg}1}, ..., x_{\text{neg}(ns-1)}] \in \mathbb{R}^{ns\times 2d}$
  • 正例と負例をconcatした中心語の分散表現を $X = concat[x_{\text{pos}}, X_{\text{neg}}] \in \mathbb{R}^{(ns+1)\times 2d}$
  • 窓幅$w$の周辺語(context)の分散表現を $X_{\text{con}} = [x_{\text{con}0}, x_{\text{con}1}, ... x_{\text{con}(w-1)}]\in \mathbb{R}^{2w\times2d}$

と表します(表記汚くてすみません)。$d$は次元数です。また、それぞれの$x_i\in \mathbb{R}^{2d}$は$x_i[0:d] = x^-_i \in \mathbb{R}^{d}, x_i[d:2d]=x^+ \in \mathbb{R}^{d}$のように構成されています。

まず、contextの分散表現を集約して、集約されたcontextの分散表現$x_{\text{pooled}}^-$と$x^+_{\text{pooled}}$を作成します。

\begin{align}
x^-_{\text{pooled}} &= \beta \log \left(
\sum_{i=0}^{w-1}\exp
\left( \frac{x^-_{\text{con}i}}{\beta} \right) 
\right) \in \mathbb{R}^{d} \tag{1.1}\\
x^+_{\text{pooled}} &= -\beta \log \left(
\sum_{i=0}^{w-1}\exp
\left( \frac{x^+_{\text{con}i}}{-\beta}\right) 
\right) \in \mathbb{R}^{d} \tag{1.2}
\end{align}

$\beta$はスカラーのハイパーパラメータです。
ここで、集約されたcontextの分散表現を$ns+1$個に複製して、$X^-_{\text{pooled}} \in \mathbb{R}^{(ns+1)\times d}$、と $X^+_{\text{pooled}} \in \mathbb{R}^{(ns+1)\times d}$とします。次に、この集約されたcontextの分散表現と中心語の分散表現の共通部分の分散表現$Z^-\in \mathbb{R}^{(ns+1)\times d}, Z^+ \in \mathbb{R}^{(ns+1)\times d}$を求めます。

\begin{align}
Z^- &= \beta_{g} \log \left(
\exp \left( \frac{X^-}{\beta_g} \right) +
\exp \left( \frac{X_{\text{pooled}}^-}{\beta_g} \right)
\right) \in  \mathbb{R}^{(ns+1)\times d} \tag{2.1}\\
Z^- &= \max\left\{ Z^-, \max\{ X^-, X^-_{\text{pooled}} \} \right\} \tag{2.2}\\
 \\
Z^+ &= -\beta_{g} \log \left(
\exp \left( \frac{X^-}{-\beta_g} \right) +
\exp \left( \frac{X_{\text{pooled}}^-}{-\beta_g} \right)
\right) \in  \mathbb{R}^{(ns+1)\times d} \tag{2.3} \\
Z^+ &= \min\left\{ Z^+, \min\{ X^+, X^+_{\text{pooled}} \} \right\} \tag{2.4}
\end{align}

$\beta_g$はスカラーのハイパーパラメータです。式(3.1)と式(3.3)でガンベル分布を利用しています。
最後に、先ほど求めた共通部分の分散表現を活性化関数(softplus)に入れ、対数を取ったものを埋め込み次元方向に総和をとって類似度 $\boldsymbol{\hat{y}}$ とします。

\begin{align}
\boldsymbol{\hat{y}} &= [\hat{y}_{\text{pos}}, \hat{y}_{\text{neg}0}, ..., \hat{y}_{\text{neg}(ns-1)}] = \sum_{col}\log f(Z^+ - Z^-) \in \mathbb{R}^{ns+1} \tag{3.1}\\
f&(Z^+ - Z^-) = \frac{1}{\beta} \log \left( 
1 + \exp\left(\beta ( Z^+ - Z^- -2\beta_g \gamma_{\text{euler}} \right)
\right) \tag{3.2}
\end{align}

$\gamma_{\text{euler}}$はオイラーの定数です。上記式で定義される類似度は共通部分の大きさを意味しています。
また、損失関数では、max-margin lossを採用しています。こちらでは、正例の類似度を大きくし、負例の類似度小さくするようにします。$\mu$はマージンです。

loss = \sum_{i=0}^{ns-1} \max\left( 0, \hat{y}_{\text{neg}i} - \hat{y}_{\text{pos}} + \mu \right) \tag{4}

実装内容

データセットの作成

訓練用データセット、及び、評価用データセットはdatasets.utilsモジュールのget_iter_on_deviceで行います。

Vocabularyの作成

## Create Vocabulary properties
print("Creating iterable dataset ...")
TEXT = torchtext.data.Field()
TEXT.stoi = vocab_stoi
TEXT.freqs = vocab_freq
TEXT.itos = [k for k, v in sorted(vocab_stoi.items(), key=lambda item: item[1])]

自然言語処理のあらゆる前処理をサポートしているtorchtext.data.Field()でvocabularyを作成しています。

  • stoi: keyに単語、valueにidを持つ辞書
  • freqs: 語彙作成にあたって使用したデータセット内における語彙の出現頻度。keyに単語、valueに出現頻度(int)を持つ辞書
  • itos: stoiの逆
# Since we won't train on <pad> and <eos>. These should not come in any sort of
# subsampling and negative sampling part.
TEXT.freqs["<pad>"] = 0
TEXT.freqs["<unk>"] = 0

if eos_mask:
    TEXT.freqs["<eos>"] = 0

特殊なトークンの出現頻度は0にして、学習に使用しないようにします。eosとは End Of Sentence の略です。

訓練用データローダー作成

## Create data on the device
print("Creating iterable dataset on GPU/CPU...")
if data_device == "gpu":
    data_device = device
train_iter = LazyDatasetLoader(
    training_tensor=train_tokenized,  # ← train_tokenized について見てみる
    n_splits=1000,
    window_size=n_gram,
    vocab=TEXT,
    subsample_thresh=subsample_thresh,
    eos_mask=eos_mask,
    device=data_device,
    batch_size=batch_size,
)

train_tokenizeddatasets.word2vecgpuLazyDatasetLoaderに入力して訓練用データローダーtrain_iterを作成しています。

訓練用データ train_tokenized について

下記のコードはtrain.ptの中身を見るための自作コードです。

# 訓練用データセット train.pt の中身を見る
train_tokenized = torch.load("./data/ptb/train.pt")
sub_train_tokenized = train_tokenized[0:100]  # 100トークン抽出
print(sub_train_tokenized)

# itos (IDから文字列) の辞書を作成
vocab_stoi = json.load(open("./data/ptb/vocab_stoi.json", "r"))
vocab_itos = [k for k, v in sorted(vocab_stoi.items(), key=lambda item: item[1])]

# トークン列を文字列にマッピング
sub_train_texts = [vocab_itos[token] for token in sub_train_tokenized]
print(sub_train_texts)

data/ptb/にあるtrain.ptというバイナリファイルに訓練用データセットが保存されています。それをロードすると、サイズ[929589]のTensorを取得することができます。上記のコードを実行した結果が以下のようになります。

python test.py 
tensor([ 9971,  9972,  9973,  9975,  9976,  9977,  9981,  9982,  9983,  9984,
         9985,  9987,  9988,  9989,  9990,  9992,  9993,  9994,  9995,  9996,
         9997,  9998,  9999, 10000,     3,  9257,     0,     4,    73,   394,
           34,  2134,     2,   147,    20,     7,  9208,   277,   408,     4,
            3,    24,     0,    14,   142,     5,     0,  5466,     2,  3082,
         1597,    97,     3,  7683,     0,     4,    73,   394,     9,   338,
          142,     5,  2478,   658,  2171,   956,    25,   522,     7,  9208,
          277,     5,    40,   304,   439,  3685,     3,     7,   943,     5,
         3151,   497,   264,     6,   139,  6093,  4242,  6037,    31,   989,
            7,   242,   761,     5,  1016,  2787,   212,     7,    97,     5])
['aer', 'banknote', 'berlitz', 'calloway', 'centrust', 'cluett', 'fromstein', 'gitano', 'guterman', 'hydro-quebec', 'ipo', 'kia', 'memotec', 'mlx', 'nahb', 'punts', 'rake', 'regatta', 'rubens', 'sim', 'snack-food', 'ssangyong', 'swapo', 'wachter', '<eos>', 'pierre', '<unk>', 'N', 'years', 'old', 'will', 'join', 'the', 'board', 'as', 'a', 'nonexecutive', 'director', 'nov.', 'N', '<eos>', 'mr.', '<unk>', 'is', 'chairman', 'of', '<unk>', 'n.v.', 'the', 'dutch', 'publishing', 'group', '<eos>', 'rudolph', '<unk>', 'N', 'years', 'old', 'and', 'former', 'chairman', 'of', 'consolidated', 'gold', 'fields', 'plc', 'was', 'named', 'a', 'nonexecutive', 'director', 'of', 'this', 'british', 'industrial', 'conglomerate', '<eos>', 'a', 'form', 'of', 'asbestos', 'once', 'used', 'to', 'make', 'kent', 'cigarette', 'filters', 'has', 'caused', 'a', 'high', 'percentage', 'of', 'cancer', 'deaths', 'among', 'a', 'group', 'of']

以下のような前処理が施されています。(2023/01/19追記)

  • トークナイズ、レンマ化されている
  • punctuation(カンマ等)が取り除かれている
  • <eos>で文が区切られている
  • 数字はNに置換されている
  • すべての文字が小文字化されている
  • アスキーコード以外の文字は除かれている
  • 出現頻度が100未満の語は語彙に含まない。<unk>で置き換えられているのはこれだと思う。

上記のtrain_tokenizedがデータセットとしてどのように処理されているのか見てみましょう。

torch.utils.data.Datasetを使った訓練用データセットの作成

train_tokenizedLazyDatasetLoaderに入力して訓練用データローダーが作成されていますが、このLazyDatasetLoadertorch.utils.dataDatasetクラスを継承しているものになっています。
データを作成する際に重要となる箇所のみ読み解きたいと思います。

    def __getitem__(
        self, idx: LongTensor
    ) -> Tuple[LongTensor, LongTensor, BoolTensor, BoolTensor]:
        # idx is a Tensor of indicies of the corpus, eg. [2342,12312312,34534,1]
        # we will interpret these as the id of the center word
        idx += self.pad_size
        # Idx is repeated to get the sliding window effect
        # For the sliding window part we add the range with idx
        window_range = torch.arange(-self.window_size, self.window_size + 1)
        idx = idx.unsqueeze(1) + window_range.unsqueeze(0)


        # idx = torch.transpose(idx.repeat(2*self.window_size+1,1), 0, 1)
        # idx = idx + torch.arange(-self.window_size, self.window_size+1)


        # Get the middle slice for the center
        # The rest of them are context
        center = self.corpus[idx[:, self.window_size]]
        context = self.corpus[
            torch.cat(
                (idx[:, : self.window_size], idx[:, self.window_size + 1 :]), dim=1
            )
        ]
        # Get do the subsampling.
        center = self.sub_sample_words(center)
        context = self.sub_sample_words(context)


        # Get rid of the dataset that has the center word as <pad>.
        # Or has all context words as <pad>.
        if not self.eos_mask:
            keep = (center != self.pad_id) & (context != self.pad_id).any(dim=-1)
            center = center[keep]
            context = context[keep]
            assert (center != self.pad_id).all()
            context_mask = torch.ones_like(context)
        else:
            keep = (
                (center != self.pad_id)
                & (context != self.pad_id).any(dim=-1)
                & (center != self.eos_token)
                & (context != self.eos_token).any(dim=-1)
            )
            center = center[keep]
            context = context[keep]
            assert (center != self.pad_id).all()
            context_mask = self.get_mask(context)
            # Mask might do away with the whole sentence. In that case remove that
            keep = (context_mask != False).any(dim=1).squeeze()
            center = center[keep]
            context = context[keep]
            context_mask = context_mask[keep]
        return {
            "center_word": center,
            "context_words": context,
            "context_mask": context_mask,
        }

上記のコードについてはよく理解できていませんが、訓練用データセットtrain_iterとして使用されるのはこの返り値となっています。

  • center_word: size() == [B]
  • context_words: size() == [B, window_size]
  • context_mask: size() == [B, window_size]
for i, batch in enumerate(train_iter):
    print(TEXT.itos[batch["center_word"][0]])
    context_words = [TEXT.itos[word] for word in batch["context_words"][0]]
    print(context_words)
    print(batch["context_mask"][0].to('cpu').detach().numpy().tolist())

"""
2021年
['日本語', ').', 'NHK', 'みんなのうた', '.', '3月20日', '閲覧', '<pad>', '<eos>', '大槻']
[True, True, True, True, True, True, True, False, False, False]

調査
['<pad>', '生息', '実態', '<pad>', '聞き込み', 'による', '過去', '<pad>', '生息', '情報']
[False, True, True, False, True, True, True, False, True, True]

2020年
['<pad>', '地下鉄', '', '##ぎょうせい##', '<pad>', '10月14日', '<pad>', '17', '', '<pad>']
[False, True, True, True, False, True, False, True, True, False]

筋力トレーニング
['効果', '<pad>', '上げる', 'ため', '', '<pad>', '<pad>', '', '##1990年##', '<pad>']
[True, False, True, True, True, False, False, True, True, False]
tensor(33, device='cuda:0')
"""

(2023/01/26 追記)
データローダーの中身を上記のコードで調べてみました。窓幅を5にしているので、周辺語は10個になっており、中心語は周辺語のちょうど真ん中の位置にあるようになっています。また、<pad><eos>はマスクでFalseになっています。

また、<eos>トークンが周辺語の前半にある場合、<eos>トークンを含むそれ以前の文がFalseにマスクされ、<eos>トークンが周辺語の後半にある場合、<eos>トークンを含むそれ以降の文がFalseにマスクされています。

モデルの実装・訓練

train.Trainer.TrainerWordSimilarityクラスにあるtrain_modelでモデルの実装・訓練が行われています。train_modelの詳細についてみてましょう。

optimizerの設定

## Setting up the optimizers
parameters = filter(lambda p: p.requires_grad, model.parameters())
optimizer = torch.optim.Adam(params=parameters, lr=self.lr)

filter関数によって、勾配を保存している parameter のみを抽出して、optimizer に Adam を使用しています。

negative sampling

# Create negative samples for the batch
batch = self.to(batch, device)
batch = self.add_negatives(batch)

add_negativesによって negative samplingが行われています。add_negativesにはnegative_samplingファイル内のRandomNegativeCBOWクラスかRandomNegativeSkipGramクラスが格納されています。

RandomNegativeCBOW

class RandomNegativeCBOW:
    """
    This augments a batch of data to include randomly sampled target center words.
    Appends the sampled words with 'center_words' of the batch
    """


    def __init__(self, number_of_samples: int = 5, sampling_distn: LongTensor = None):
        self.number_of_samples = number_of_samples
        self.sampling_distn = sampling_distn


    def __call__(self, batch) -> LongTensor:
        x, y = batch["context_words"].shape
        negatives = torch.multinomial(
            self.sampling_distn,
            num_samples=self.number_of_samples * x,
            replacement=True,
        ).resize(x, self.number_of_samples)
        batch["center_word"] = torch.cat(
            (batch["center_word"].unsqueeze(1), negatives), dim=-1
        )
        return batch

RandomNegativeCBOWでは入力されたbatchcenter_wordを更新します。

  • 入力前: center_word.size() == [B] (Bはバッチサイズ)
  • 入力後: center_word.size() == [B, ns+1] (nsはネガティブサンプリング数)

偽の中心語がネガティブサンプリングされた形式になります。

RandomNegativeSkipGram

RandomNegativeSkipGramクラスでは偽のcontextsをネガティブサンプリングする形式となっています。デフォルトでは前述のCBOWの方式で訓練を行なっています。

スコアの計算(順伝搬)

# Start the optimization
optimizer.zero_grad()
score = model.forward(
    batch["center_word"],
    batch["context_words"],
    batch["context_mask"],
    train=True,
)

順伝搬をどのように計算しているかを詳しく見てみましょう。使用しているモデルはWord2Boxを継承しているWord2BoxConjunctionとなっているので、まずWord2Boxの実装についてみていこうと思います。

Word2Box (埋め込み表現の初期化)

class Word2Box(BaseModule):
    def __init__(
        self,
        TEXT=None,
        embedding_dim=50,
        batch_size=10,
        n_gram=4,
        volume_temp=1.0,
        intersection_temp=1.0,
        box_type="BoxTensor",
        **kwargs
    ):
        super(Word2Box, self).__init__()


        # Model
        self.batch_size = batch_size
        self.n_gram = n_gram
        self.vocab_size = len(TEXT.itos)
        self.embedding_dim = embedding_dim


        # Box features
        self.volume_temp = volume_temp
        self.intersection_temp = intersection_temp
        self.box_type = box_type


        # Create embeddings
        self.embeddings_word = BoxEmbedding(
            self.vocab_size, self.embedding_dim, box_type=box_type
        )
        self.embedding_context = BoxEmbedding(
            self.vocab_size, self.embedding_dim, box_type=box_type
        )

Word2BoxConjunctionに関係がある__init__関数のみ焦点を当てます。
ここでは、BoxEmbeddingクラスにより、vocabularyの埋め込み表現の初期化を行なっています。torch.nn.init.uniform_により一様分布で初期化されています。初期化時のテンソルのサイズは[box_embedding_dim*2]となっており、テンソルの後ろ半分は前半分の値を+0.1した値となっています。

Word2BoxConjunction (順伝搬)

class Word2BoxConjunction(Word2Box):
    def intersect_multiple_box(self, boxes, mask):
        beta = self.intersection_temp
        z = boxes.z.clone()
        Z = boxes.Z.clone()

        # maskが-infなので、対数を取ると限りなく0に近づく
        z[~mask] = float("-inf")
        Z[~mask] = float("inf")
        z = beta * torch.logsumexp(z / beta, dim=1, keepdim=True)
        Z = -beta * torch.logsumexp(-Z / beta, dim=1, keepdim=True)

        return BoxTensor.from_zZ(z, Z)  # [2, box_embedding_dim]

    def forward(self, idx_word, idx_context, mask_context, train=True):
        context_boxes = self.embedding_context(idx_context)  # Batch_size * 2 * dim
        # Notce that the context is not masked yet. Need to mask them as well.

        word_boxes = self.embeddings_word(idx_word)  # Batch_size * ns+1 * 2 * dim
        pooled_context = self.intersect_multiple_box(context_boxes, mask_context)  # Batch_size * 

        if self.intersection_temp == 0.0:
            score = word_boxes.intersection_log_soft_volume(
                pooled_context, temp=self.volume_temp
            )
        else:
            score = word_boxes.gumbel_intersection_log_volume(
                pooled_context,
                volume_temp=self.volume_temp,
                intersection_temp=self.intersection_temp,
            )
        return score

先ほどのWord2Boxを継承しているWord2BoxConjunctionについて詳しくみてます。このクラスのforwardが実際にスコアの計算(順伝搬)をしているところとなります。

intersect_multiple_boxでは複数のBoxの共通部分のBoxを計算するところとなっています。
$$
\operatorname{logsumexp}(x) = \log \sum_i \exp(x_i)
$$
上記の式のように、LogSumExpと呼ばれる方法で共通部分を計算しています。

そしてforwardでは、contextの共通部分であるpooled_contextと中心語のboxであるword_boxesとの類似度(スコア)を計算しています。それぞれのテンソルのサイズは以下の通りになっています。

print(f"idx_context.size()=={idx_context.size()}")
z = context_boxes.z.clone()
print(f"context_boxes.z.size()=={z.size()}")
print(f"idx_word.size()=={idx_word.size()}")
z = word_boxes.z.clone()
print(f"word_boxes.z.size()=={z.size()}")
z = pooled_context.z.clone()
print(f"pooled_context.z.size()=={z.size()}")

"""
idx_context.size()==torch.Size([2767, 10])  # [B, 2*window_size]                                      
context_boxes.size()==torch.Size([2767, 10, 64])  # [B, 2*window_size, embedding_dim]
idx_word.size()==torch.Size([2767, 3])  # [B, ns+1]
word_boxes.size()==torch.Size([2767, 3, 64])  # [B, ns+1, embedding_dim]
pooled_context.z.size()==torch.Size([2767, 1, 64])  # [B, 1, embedding_dim]
"""

次項で2通りのスコアの算出方法について詳細にみてみようと思います。

intersection_log_soft_volume
    def intersection_log_soft_volume(
        self,
        other: TBoxTensor,
        temp: float = 1.0,
        gumbel_beta: float = 1.0,
        bayesian: bool = False,
        scale: Union[float, Tensor] = 1.0,
    ) -> Tensor:
        z, Z = self._intersection(other, gumbel_beta, bayesian)
        vol = self._log_soft_volume(z, Z, temp=temp, scale=scale)

        return vol

self._intersectionで中心語とコンテキストの共通部分を算出しています。

def _intersection(
        self: TBoxTensor,
        other: TBoxTensor,
        gumbel_beta: float = 1.0,
        bayesian: bool = False,
    ) -> Tuple[Tensor, Tensor]:
        t1 = self
        t2 = other

        # ~~~略~~~ #
        
        # intersection_log_soft_volume で使う
        else:
            z = torch.max(t1.z, t2.z)
            Z = torch.min(t1.Z, t2.Z)

        return z, Z

intersection_log_soft_volume中の_intersectionでは、単純にz($X^-$)の最大値とZ($X^+$)の最小値をとって、中心語とコンテキストの共通部分を計算しています。

    @classmethod
    def _log_soft_volume(
        cls, z: Tensor, Z: Tensor, temp: float = 1.0, scale: Union[float, Tensor] = 1.0
    ) -> Tensor:
        eps = torch.finfo(z.dtype).tiny  # type: ignore


        if isinstance(scale, float):
            s = torch.tensor(scale)
        else:
            s = scale


        return torch.sum(
            torch.log(F.softplus(Z - z, beta=temp) + 1e-23), dim=-1
        ) + torch.log(
            s
        )  # need this eps to that the derivative of log does not blow

そして、_log_soft_volumeでは、共通部分のzZをもとにsoftplusという活性化関数を用いてスコアを算出しています。

score = \log\left(\frac{1}{\beta}・\log(1+\exp(\beta・(X^+ - X^-))\right)

softplus関数はReLUと同じような形をしていますが、

  • 関数への入力値が0付近では、出力値が0にはならない
  • 入力値が小さくなればなるほど出力値が0に近づいていき、入力値が大きくなればなるほど出力値が入力値と同じ値に近づいていく

といった特徴があります。

gumbel_intersection_log_volume
    def gumbel_intersection_log_volume(
        self: TBoxTensor,
        other: TBoxTensor,
        volume_temp=1.0,
        intersection_temp: float = 1.0,
        scale=1.0,
    ) -> TBoxTensor:
        z, Z = self._intersection(other, gumbel_beta=intersection_temp, bayesian=True)
        vol = self._log_soft_volume_adjusted(
            z, Z, temp=volume_temp, gumbel_beta=intersection_temp, scale=scale
        )
        return vol

こちらでも前述と同様に_intersectionで中心語とコンテキストの共通部分を算出しています。

    def _intersection(
        self: TBoxTensor,
        other: TBoxTensor,
        gumbel_beta: float = 1.0,
        bayesian: bool = False,
    ) -> Tuple[Tensor, Tensor]:
        t1 = self
        t2 = other


        if bayesian:
            try:
                z = gumbel_beta * torch.logaddexp(
                    t1.z / gumbel_beta, t2.z / gumbel_beta
                )
                z = torch.max(z, torch.max(t1.z, t2.z))
                Z = -gumbel_beta * torch.logaddexp(
                    -t1.Z / gumbel_beta, -t2.Z / gumbel_beta
                )
                Z = torch.min(Z, torch.min(t1.Z, t2.Z))
            except Exception as e:
                print("Gumbel intersection is not possible")
                breakpoint()
        
        # ~~~略~~~ #

        return z, Z

gumbel_intersection_log_volume中の_intersectionではbayesianモードで共通部分を算出しています。数式で表すと以下のようになります。

\begin{align}
X^- &= \beta_g  \log \left(
     \exp \left( \frac{X^-_{t1}}{\beta_g} \right) 
    + \exp \left( \frac{X^-_{t2}}{\beta_g} \right)
\right) \\
X^-  &= \max\left\{ X^-, \max\{X^-_{t1}, X^-_{t2}\} \right\} \\ \\

X^+ &= -\beta_g  \log \left(
     \exp \left( \frac{-X^+_{t1}}{\beta_g} \right) 
    + \exp \left( \frac{-X^+_{t2}}{\beta_g} \right)
\right) \\
X^+  &= \min\left\{ X^+, \min\{X^+_{t1}, X^+_{t2}\} \right\}
\end{align}

Gumbel分布を利用してBoxの頂点を算出しています。Gumbel分布を利用することで、局所識別性(local indentifiability)を向上させることを望んでいます。

    @classmethod
    def _log_soft_volume_adjusted(
        cls,
        z: Tensor,
        Z: Tensor,
        temp: float = 1.0,
        gumbel_beta: float = 1.0,
        scale: Union[float, Tensor] = 1.0,
    ) -> Tensor:
        eps = torch.finfo(z.dtype).tiny  # type: ignore


        if isinstance(scale, float):
            s = torch.tensor(scale)
        else:
            s = scale


        return (
            torch.sum(
                torch.log(
                    F.softplus(Z - z - 2 * euler_gamma * gumbel_beta, beta=temp) + 1e-23
                ),
                dim=-1,
            )
            + torch.log(s)
        )

そして、_log_soft_volume_adjustedでは、前述の_log_soft_volumeと同様に、共通部分のzZをもとにsoftplus関数を用いてスコアを算出しています。

score = \log\left(\frac{1}{\beta}・\log(1+\exp(\beta・(X^+ - X^- -2\beta_{g} \gamma_{\operatorname{euler}}))\right)

Loss function (損失関数)

assert (
    score.shape[-1] == self.negative_samples + 1  # .size()==[B, ns+1]
)  # check the shape of the score


# Score log_intersection_volume (un-normalised) for Word2Box
pos_score = score[..., 0].reshape(  # .size()==[B, 1]
    -1, 1
)  # The first element correspond to the Positive
neg_score = score[..., 1:].reshape(  # .size()==[B, ns]
    -1, self.negative_samples
)  # The rest of the elements are for negative samples

損失を計算する前に positive sample のスコアと negative samples のスコアを分けます。

# Calculate Loss
loss = self.loss_fn(
    pos_score, neg_score, margin=self.margin
)  # Margin is not required for nll or nce
# Handled through kwargs in loss.py
total_loss = torch.sum(loss)
avg_loss = torch.mean(loss)
if torch.isnan(loss).any():
    raise RuntimeError("Loss value is nan :(")

loss_fnとして定義される損失関数はloss.py中にある、nll(Negative Log Likelihood)、nce(Noise Contrastive Estimation)、max_marginの3種類から選択できますが、ソールコード上でも論文でもmax_marginが採用されていました。

max_margin

def max_margin(pos, neg, margin=5.0):
    """
	This is max margin loss for box embeddings.
	Here, the input scores can be un-normalised. The object here
	is to make increse the pos similarity score more than a margin
	from the negative scores. If that margin is satisfied then the
	loss is zero.

	Args:
	    pos: Unnormalised similarity(maybe log in case of Boxes) score for positives.
	    neg: Unnormalised similarity(maybe log in case of Boxes) score for negatives.
	Output:
	    loss =  - max(0, pos - mean(neg) + margin)
	"""
    # Replicate the positive score number of negative sample times
    zero = torch.tensor(0.0).to(device)
    return torch.sum(torch.max(zero, neg - pos + margin), dim=1)
loss = \sum  \max\left( 0, score_{\text{neg}i} - score_{\text{pos}} + \mu \right)

モデルの評価

モデルの評価はtrain.Trainer.TrainerWordSimilarityクラスのmodel_evalで行います。

def word_similarity(self, w1, w2):
    with torch.no_grad():
        word1 = self.embeddings_word(w1)
        word2 = self.embeddings_word(w2)
        if self.intersection_temp == 0.0:
            score = word1.intersection_log_soft_volume(word2, temp=self.volume_temp)
        else:
            score = word1.gumbel_intersection_log_volume(
                word2,
                volume_temp=self.volume_temp,
                intersection_temp=self.intersection_temp,
            )
        return score

評価に使用する予測値は、上記のword_similarity関数の出力値を利用します。2つの単語に対してintersection_log_soft_volume関数もしくは、gumbel_intersection_log_volume関数を用いてスコアの計算を行っています。こちらの計算方法については前述しているのでそちらを見返してみてください。

correlation = spearmanr(predicted_scores, real_scores)[0]

そして、先ほど得た予測値と評価用データセットのラベル値に対して、スピアマンの順位相関係数で評価を行っています。スピアマン順位相関係数は、2変数が順位データや5段階評価データなどの順序尺度の場合に有効な評価手法です。

学習済みモデルの利用方法

モデルを読み込む

モデルを呼び出すための実行ファイルはsrcと同階層に配置しました。
@clickのパラメータの呼び出し方法が理解できていないため、とりあえず不細工ですがこのようにしました。絶対にもっといい方法があると思います。。。
もっといいモデルのloadの仕方わかる方いたら是非ご教示いただきたいです!

# モデルのインスタンス作成に必要なモジュールをインポート
from src.language_modeling_with_boxes.models import Word2Box, Word2Vec, Word2VecPooled, Word2BoxConjunction, Word2Gauss
from src.language_modeling_with_boxes.datasets.utils import get_iter_on_device
from src.language_modeling_with_boxes.__main__ import main

import torch
from torch import LongTensor, BoolTensor, Tensor, IntTensor
import pickle, json


# 保存してあるモデルと同じパラメータを設定 (すごい無理矢理ですが)
config = {
    "batch_size": 4096,
    "box_type": "BoxTensor",
    "data_device": "gpu",
    "dataset": "ptb",
    "embedding_dim": 64,
    "eos_mask": True,
    "eval_file": "../data/similarity_datasets/",
    "int_temp": 1.9678289474987882,
    "log_frequency": 10,
    "loss_fn": "max_margin",
    "lr": 0.004204091643267762,
    "margin": 5,
    "model_type": "Word2BoxConjunction",
    "n_gram": 5,
    "negative_samples": 2,
    "num_epochs": 10,
    "subsample_thresh": 0.001,
    "vol_temp": 0.33243242379830407,
    "save_model": "",
    "add_pad": "",
    "save_dir": "results",
}

# 語彙やデータローダーを作成
TEXT, train_iter, val_iter, test_iter, subsampling_prob = get_iter_on_device(
    config["batch_size"],
    config["dataset"],
    config["model_type"],
    config["n_gram"],
    config["subsample_thresh"],
    config["data_device"],
    config["add_pad"],
    config["eos_mask"],
)

# モデルのインスタンスを作成
model = Word2BoxConjunction(
    TEXT=TEXT,
    embedding_dim=config["embedding_dim"],
    batch_size=config["batch_size"],
    n_gram=config["n_gram"],
    intersection_temp=config["int_temp"],
    volume_temp=config["vol_temp"],
    box_type=config["box_type"],
)

# 作成したインスタンスに保存してあるパラメータを読み込む
model.load_state_dict(torch.load('results/best_model.ckpt'))
print(model)

類似度を計算する

word_1 = "dog"
word_2 = "cat"

# 文字列 を ID に変換
word_1_id = TEXT.stoi["dog"]
word_2_id = TEXT.stoi["cat"]

# INT の ID を LongTensorにキャストした後で類似度を計算
similarity = model.word_similarity(LongTensor([word_1_id]), LongTensor([word_2_id]))
print(similarity)

参考文献

18
9
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
18
9

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?