大規模言語モデル入門 の輪読会を開催したので、発表に使った資料を一部修正して公開します。
Transformerとは?
Googleが2017年に提案したニューラルネットワークモデルのアーキテクチャ。
Transformer自体はモデルではなく、あくまでモデル内構造の1つ。
Transformerのざっくり背景
- Transformerはなぜ生まれたのか。
- 今まで系列データ(文章など)の処理にはRNNが使われていた。
- RNNなどのSeq2Seqには、主に課題が2つあった。
- 系列の位置情報を考慮しにくい(離れた位置との依存関係を考慮しにくい)
- 学習の高速化が難しい。
- Transformerには3パターン存在するらしい
- エンコーダのみモデル(BERTはこれ)
- 主に使われるタスク:文章分類、クラスタリング、対話モデル
- デコーダのみモデル(GPTはこれ)
- 主に使われるタスク:文章生成、要約生成
- エンコーダ&デコーダモデル
- 主に使われるタスク:機械翻訳、text2code
- エンコーダのみモデル(BERTはこれ)
- Transformerが出現する以前にこれらは存在していたモデルではある。
エンコーダとデコーダ
- エンコーダとは?
- シーケンスを高次元の特徴表現に変換する役割を持つ。
- 文章入れるとただ数値の行列が出てくるイメージ。
- デコーダとは?
- 主にシーケンスの生成を行う役割を持つ。
- デコーダによって、文章が出てくる。
エンコーダとデコーダを構成する仕組み
エンコーダとデコーダを構成には、下記のような手法が組み込まれている。
-
エンコーダ
- 入力トークン埋め込み(Word Embedding)
- 位置符号化(Position Encoding)
- 自己注意機構(Self Attention)
- 加算&層正規化(Add&Layer Normalizationa)
- マルチヘッド注意機構 < こいつがTransformerのミソ
-
デコーダ
- 入力トークン埋め込み(Word Embedding)
- 位置符号化(Position Encoding)
- 自己注意機構(Self Attention)
- 残差結合&層正規化(Add&Layer Normalizationa)
- フィードフォワード層
- マルチヘッド注意機構 < こいつがTransformerのミソ
- 交差注意機構(Cross Attention) < デコーダのみに存在する。
入力トークン埋め込み(word embedding)とは?
- トークンの意味をV×Dの行列で表現する仕組み。
- Vは単語の数
- Dは次元数
- ここのD次元はハイパーパラメータなので、タスクに応じて適切に決める必要があるらしい。次元を増やすほど、表現が豊かになるので精度は上がる可能性が高まる?が、計算時間がかかる。
- ちなみにBERTは768次元、GPT-3は2048次元らしい。
位置符号化(Position Encoding)とは?
- トークンの位置を表現する仕組み。
- 正弦関数を使って位置符号を付与する。
- 正弦関数(sinとかcos)を使うと何がいいのか。
- 相対的な単語の位置を表現しやすい。
- 例えばpos=1とpos=10000の波形が全然ちがう。
- つまり文脈的な関係性はほとんどないと捉えられる。
詳細な内容については、下記の記事がとてもわかりやすいです。
自己注意機構(Self Attention)とは?
- モデルが文脈を理解する仕組み
- 例えば 「マウス🐭」と「マウス🖱️」の意味を判別するために使う。
- そもそもAttentionとはなんぞや?
- Attentionは「重要な要素の重みを高くする機構」と理解すれば良さそう。
- QueryとMemory(Key, Value)が基本となる。
- QueryによってKeyを検索し、Valueを取得する(Pythonのdictに近い)
- このQueryとKeyを同じトークンで構成し、最終的にトークン間の関連性スコアを算出する。
- QueryとKeyは異なる埋め込みが付与されるが、これは同じ埋め込みだと必然的にトークン自身が1番高くなってしまうことに起因する。
マルチヘッド注意機構(Multi-head Attention)とは?
- Attentionを同時に処理するための仕組み。
- Transformerのミソらしい。
- 何がスゴいのか。
- 早く計算できる。
- それぞれのヘッドが異なる特徴を学習し、文脈を理解するようになっている。
- 通常のAttentionの学習では、長文扱うと精度がどんどん落ちていくらしい。
交差注意機構(Cross Attention)とは?
- 役割は自己注意機構と同じだが、エンコーダの情報をMemory(Key, Value)として使い、デコーダの翻訳文をQueryとして入力にいれる。
残差結合と層正規化とは?
- 両方とも学習を安定させるための仕組み
- 残差結合は勾配消失(勾配が小さくなりすぎること)を防ぐ。
- 層正規化は値が大きくなりすぎることによって訓練が不安定になることを防ぐ。
フィードフォワード層
- 非線形活性化関数(ReLUなど)で非線形変換を行うことができる。
- 非線形だと複雑な関数をモデル化しやすいので表現力が上がる。
ドロップアウト
- 一部のベクトルを学習に用いないようにして、過学習を防ぐ仕組み。