1
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

Transformerのメモ

Last updated at Posted at 2019-03-30

注意

合ってるの?って言われるとわからない

元の論文

他の人の解説

超詳しい解説、コードを書きながら何やってるかわかる(日本語):
https://qiita.com/halhorn/items/c91497522be27bde17ce

論文とコードが一緒に読める(英語):
http://nlp.seas.harvard.edu/2018/04/03/attention.html

大枠から解説(英語):
http://jalammar.github.io/illustrated-transformer/

何が新しいのか・何が今までどおりなのか

ここで言う「今まで」とはRNN系列のモデルを指します
具体的にはAttention LSTM
モデルとしては[Effective Approaches to Attention-based Neural Machine Translation] (http://aclweb.org/anthology/D15-1166)の図がわかりやすい

今までどおりなところ

1. Decoderの中間層を利用してEncoderの各時刻(各token)の出力に対して重みを計算し、

2. Encoderの出力の重み付け和をDecoderの中間層に合成(concat/sum)してる

ところ

AttentionLSTMではDecoderの中間層とconcat
image.png
参照:http://aclweb.org/anthology/D15-1166

TransformerではDecoderの中間層とsum
image.png
参照:http://jalammar.github.io/illustrated-transformer/

推論時の計算

LSTMでもTransformerでも時系列順に計算する必要がある
どういうことかというと、Transformerであっても、推論を時間ごとに並列化して計算するのは無理・・・だと思う
なぜなら、出力tokenを$t_1,...t_n$としたとき、token $t_i$の出力確率を計算するには入力にtoken $t_1 ... t_{i-1}$ が必要になるからである

図でいうと下の通り
image.png
参照:http://jalammar.github.io/illustrated-transformer/

何が違うのか

学習時の計算

Transformerは学習を並列して行うことができる

並列って具体的にどういうことかというと、
入力token列$t_{i_1}...t_{i_k}$を利用したtoken$t_{o_1}$の出力確率の計算と、
入力token列$t_{i_1}...t_{i_k}$と出力token列$t_{o_1}$を利用したtoken$t_{o_2}$の出力確率の計算と、
入力token列$t_{i_1}...t_{i_k}$と出力token列$t_{o_1}...t_{o_2}$を利用したtoken$t_{o_3}$の出力確率の計算と、...
というを全部バラバラに行えるということである
これは何故かというと、出力token列は学習データにおいては既知であるためである

LSTMの場合、token $t_{o_i}$の出力確率を計算するためには、時刻$t_{i-1}$におけるDecoderの中間表現が必要になる**(=計算が前の時刻の計算に依存している)**
よって並列化が出来ないのである

位置の考慮

Transformerでは、各tokenが文章の何番目に当たるかを考慮している
これはPositional Encodingという、次元ごとに異なる周期を持つ信号(sin/cos波)をTokenのEmbeddingに足し合わせることで実現されている、具体的には次の関数
image.png
参考:https://arxiv.org/abs/1706.03762

偶数次元ではsin、奇数次元ではcosの信号になっているようだ
また1周するために必要なtoken数は$2π*10000^{2i/D_{model}}$
ここでiはベクトルの各次元に相当する
iが大きくなればなるほど(ベクトルの高次元に行けばいくほど)一周に時間がかかる(=周期の長い)信号となっていることが数式からわかる

図からもわかる
image.png
参考:http://jalammar.github.io/illustrated-transformer/
ただし、この図では偶数次元/奇数次元ではなくて、全N次元とすると1~N/2次元、N/2+1次元~N次元でsin/cosを分けているようだ

周期的な信号と聞くと「一周したら最初の単語と位置区別できなくなっちゃうじゃん!」と思うが、高次元になればなるほど一周に時間かかるようになってる+各次元で周期が違うので、全ての次元の周期がきれいに初期状態に戻るには相当な時間がかかる(≒実質的にどんな長さの文章でも入力可能)ように思える
(ちゃんと計算してないので不明)

整理してもやっぱりちょっと疑問は残るといえば残る
足し合わせるんじゃなくて単純にconcatじゃいけないんだろうか
concatするのであればわざわざ次元をtoken embeddingに合わせる必要もない
位置情報を保存するための次元をn-bit確保しといてそれに入れるのでは駄目なんだろうか

bit立てる場合だと桁上りの際に一気に違う信号になるので、連続的な変化をするsin波を使う気持ちはわかる
でもconcatじゃ駄目なのか?という気持ちは拭えない

1
3
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?