5
8

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

概要

PyTorchの自然言語処理をしていると、EmbeddingBagというやつが出てくるので、これは何?という話。
超初歩的な話なので、詳しい方は見なくて大丈夫です。

時間がない人向け

Embedding: Outputのshape=$(InputLength \times EmbeddingDim)$
$$
Output = x_{one-hot} \cdot W
$$

EmbeddingBag: Outputのshape=$(EmbeddingDim)$
$$
Output = sum_i(x_{one-hot} \cdot W)
$$

解説

まずは何が行われているのかを簡単に解説します。

変数の定義

\begin{align}
& Input \quad x: \quad(InputLength)\\
& Embedding Weight \quad W: \quad(VocabLength \times EmbeddingDim)\\
& Output \quad y
\end{align}

入力

文章をインデックス化した、トークン列。下記例だと[0,3,1,2]を渡せば良い。
image.png

内部の演算的には、入力xは下記のような、行:入力単語数x列:Vocabulary数のOne-Hotな行列に変換される。
下記例はVocabulary数が5で、入力長さが[i read a book]の4の場合。

x_{one-hot}=\begin{pmatrix}
1 & 0 & 0 & 0 & 0\\
0 & 0 & 0 & 1 & 0 \\
0 & 1 & 0 & 0 & 0 \\
0 & 0 & 1 & 0 & 0
\end{pmatrix}

Embeddingの計算

ここはシンプルに埋め込み行列との掛け算が行われる。埋め込み行列は行:Vocabolary数x列:EmbeddingDim数の行列である。これと入力xとのいわゆる内積で計算が行われる。例えば上記例のようにVocabolary5、EmbeddingDimが2の場合は、下記のようなWeightMatrixとなる。

W=\begin{pmatrix}
w_{11} & w_{12} \\
w_{21} & w_{22} \\
w_{31} & w_{32} \\
w_{41} & w_{42} \\
w_{51} & w_{52} \\
\end{pmatrix}

この出力yは、行:InputLength入力、列:EmbeddingDimとなり、要素を書けば下記のようになる。

y = x_{one-hot} \cdot W\\
=\begin{pmatrix}
w_{11} & w_{12} \\
w_{41} & w_{42} \\
w_{21} & w_{22} \\
w_{31} & w_{32}
\end{pmatrix}

この出力が通常のEmbedding。

EmbeddingBagの計算

公式のDocumentはこちら
EmbeddingBagは全く難しいことはなく、この最後に得られたyをもう1段階処理列方向にを行うもの。例えばデフォルトではmode="sum"となっているが、その場合はこうなる。

output=\begin{pmatrix}
w_{11}+w_{41}+w_{21}+w_{31} & w_{12}+w_{42}+w_{22}+w_{32}
\end{pmatrix}
=\Sigma_i(y_{ij})

語順に関係なく、全部足し合わせたものである。
もう想像に難くないが、mode="max"とした場合には、上記の$\Sigma$がmax演算となる。

確認

実際に動かして動作を見てみます。

Embeddingの確認

実際に上の例をやってみることにする。
再現できるように、Weightはseedで固定している。

import torch
from torch import nn

# Embedding
seed = 0
vocab_size = 5
embed_dim = 2

embedding = nn.Embedding(vocab_size, embed_dim)
torch.manual_seed(seed)
embedding.weight.data.uniform_(-1, 1)
init_weight = embedding.weight.data

print(embedding.weight)
embedding.weightの表示結果
Parameter containing:
tensor([[-0.0075,  0.5364],
        [-0.8230, -0.7359],
        [-0.3852,  0.2682],
        [-0.0198,  0.7929],
        [-0.0887,  0.2646]], requires_grad=True)

実際に[0,3,1,2]の文章を埋め込んでみる。
上の結果と合わせて見れば、weightの0,3,1,2番目がOne-Hotnによって取得されているのがよくわかる。

text_idx = torch.LongTensor([0, 3, 1, 2])
print(embedding(text_idx))
embedding(text_idx)の表示結果
tensor([[-0.0075,  0.5364],
        [-0.0198,  0.7929],
        [-0.8230, -0.7359],
        [-0.3852,  0.2682]], grad_fn=<EmbeddingBackward0>)

EmbeddingBagの確認

mode="sum"で解説したとおりの確認を行う。

# EmbeddingBag
mode = "sum"
embedding_bag = nn.EmbeddingBag(vocab_size, embed_dim, mode=mode)
embedding_bag.weight.data = init_weight

offset = torch.LongTensor([0])
print("EmbeddingBagの結果")
print(embedding_bag(text_idx, offset))

print("Embedding+加算の結果")
print(embedding(text_idx).sum(axis=0))

まったく同様の結果となることが確認できる。

出力結果
EmbeddingBagの結果
tensor([[-1.2355,  0.8616]], grad_fn=<EmbeddingBagBackward0>)
Embedding+加算の結果
tensor([-1.2355,  0.8616], grad_fn=<SumBackward1>)

offsetって何?

最後にEmbeddingBagのoffsetって何?という話。EmbeddingBagは1次元の入力を、offset位置によって区切ってembedingすることが可能。
2つの例でご紹介。

  • 例えば一文字ずつ埋め込みたいなら、長さ分だけの配列を用意してあげる。上の例ならoffset=[0,1,2,3]として叩くだけ。
offset = torch.LongTensor([0,1,2,3])
print(embedding_bag(text_idx, offset))

>tensor([[-0.0075,  0.5364],
        [-0.0198,  0.7929],
        [-0.8230, -0.7359],
        [-0.3852,  0.2682]], grad_fn=<EmbeddingBagBackward0>)
  • 例えば複数の文を全部flatに繋いだあとに、文頭をoffsetに指定してあげれば、勝手に切り取って埋め込んでくれる。PyTorchのExampleはこの例かな。イメージはこういう感じ。このときの設定は、offset=[0, 4]となる。(Vocabolaryが増えてるじゃないかというツッコミはなしで。)
    image.png
5
8
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
8

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?