さすがにattentionを理解しないといけないと思ったので、
なんとか勉強してまとめました
attentionとは
attentionは、情報源のどこを見れば、問題に答えられるかを学習するモデルです
attentionの重要な要素として、Key, Query, Valueがあります
Keyは、情報源のそれぞれの情報に対応したラベルです
Queryは、問題に答えるために必要な情報は何かです
Valueは、情報源のそれぞれの情報そのものです
つまり、Keyはwebページのタイトル、Valueはwebページの本文です
Queryは、たとえば問題が「もっとも甘い果物は?」として、「果物とは」「甘さとは」「果物 甘さ 一覧」みたいな感じです
ちなみに、attentionは注意機構と訳されます
「1984」に出てきそうなネーミングですね
空気を読めるattention
従来の手法では、文章を構成するトークンをただ加算していました
一方、attentionは、トークン間の関係を計算に入れています
具体的には、それぞれのトークンは、他のトークンとその関連性の積を加算していったものに更新されていきます
つまり、辞書的な意味から文中における意味に変化したということです
今までは辞書を片手に翻訳しており、attentionは参考書やさまざまな図書を横断的に参照して意訳しているようなものです
Source target Attention
具体的に説明していくため、Source target Attentionというモデルをもとに解説します
以降では、情報源をsource, 問題をtargetと呼びます
データの前処理
source、targetのどちらも、そのままではattentionに渡せません
これらのデータは、分散表現というものに変換する必要があります
分散表現については他の記事を参考にしてください
データに含まれるトークン数が$N$、分散表現の次元数が$d_z$とすると、
データの形状は、$(N, 1) => (N, d_z)$に変換されます
Key, Query, Valueはどこからやってくるのか
Key, Query, Valueは以下のように定義されます
\displaylines{
\boldsymbol{K} = \boldsymbol{X_s}\boldsymbol{W}^K \\
\boldsymbol{Q} = \boldsymbol{X_t}\boldsymbol{W}^Q \\
\boldsymbol{V} = \boldsymbol{X_s}\boldsymbol{W}^V \\
}
なお、$\boldsymbol{X}$は、データを分散表現にしたもの
$\boldsymbol{W}$は学習するパラメータで、最初は意味のない数値となっています
$\boldsymbol{X}$と$\boldsymbol{W}$形状は以下のようになっています
\displaylines{
\boldsymbol{X_s} \in \mathbb{R}^{N_s \times d_z} \\
\boldsymbol{X_t} \in \mathbb{R}^{N_t \times d_z} \\
\boldsymbol{W^K} \in \mathbb{R}^{d_z \times d_k} \\
\boldsymbol{W^Q} \in \mathbb{R}^{d_z \times d_q} \\
\boldsymbol{W^V} \in \mathbb{R}^{d_z \times d_v} \\
}
したがって、Key, Query, Valueの形状は以下になります
\displaylines{
\boldsymbol{K} \in \mathbb{R}^{N_s \times d_k} \\
\boldsymbol{Q} \in \mathbb{R}^{N_t \times d_q} \\
\boldsymbol{V} \in \mathbb{R}^{N_s \times d_v} \\
}
なお、便宜のため区別しましたが、$d_q$と$d_k$は同値である必要があるため、
一般に同じ記号で表します
d_kとd_vを同じにすることも多いそうです
検索してみよう!
KeyとQueryの内積を求め、softmax関数に通すことで、必要な情報が何かを知ることができます
検索ワードを打ち込んで、関連度順でwebの記事がヒットするイメージです
数式で表すと以下のような感じです。
\boldsymbol{a} = softmax(\frac{\boldsymbol{Q}\cdot\boldsymbol{K}^T}{\sqrt{d}}) \in \mathbb{R}^{N_t \times N_s}
$\boldsymbol{a}$から必要な情報を取得しましょう
以下のようにします
\boldsymbol{y} = \boldsymbol{a}\boldsymbol{V} \in \mathbb{R}^{N_t \times d_v}
$\boldsymbol{y}$には、必要な情報を中心に要約されたものが入っています
何に使うのか
Source target Attentionは異なる種類の情報をマッピングすることに用いられます
例として翻訳が挙げられます
これは、日本語と英語といった言語間のマッピングをすることで翻訳を可能にしています
「りんご」というQueryを受けて、「apple」に対応するKeyが反応して、「apple」が返される感じで、QueryとKeyの間のマッピングを学習することで翻訳を可能にしています
Self Attention
Source target Attentionでは、sourceとtargetの2つのデータを用いました
Self Attentionでは、1つのデータのみを用います
言い換えれば、sourceとtargetが同じデータということです
そのため、$X_s$,$X_t$を$X$に置き換えるだけでいいです
自分自身を参照するから「Self」ということですね
mask
数理的にはこのくらいしか違わないのですが、実際の用いられ方は少し違います
Self Attentionは、継ぎ足しで文章生成をすることができるのですが、この場合のsourceは生成途中で未完成の自分自身です
したがって、学習時にも未完成のsourceを用いる必要があります
そこで用いられるのが、maskです
以下の文を考えます
Attention is all you need.
学習時、「all」の次を予測させるためには、「Attention is all」までを学習データとして渡す必要があります
「you need」を渡すとカンニングになってしまいます
穴埋め問題なのに穴が埋まっているようなものです
これだと、ただの書き写しになってしまいます
したがって、sourceはtargetにない情報、つまり未来の情報を持ってはいけないのです
何に使うのか
Self Attentionは文章内の構造を把握するのに用いられます
「あれ」というのは何のことか、とか、文章の暗示的な意味を理解できます
「あれ」というQueryを受けて、「りんご」に対応するKeyが反応して、「りんご」が返され、ここからさらに「りんご」=>「庭の木」=>「隣の家」という感じで連鎖し、「あれ」が隣の家の庭で育てられている木になるりんごを指していると機械が理解してくれます
Multihead Attention
これは、Key, Query, Valueの組を複数持ち、それぞれが自動的に別視点で学習していくことで、多角的に文章を理解できる仕組みです
仕組み
Key, Query, Valueの組はheadと呼ばれます
各headは学習の過程で、それぞれ特異な部分を重点的に学習するようになります
例えば、あるheadは文法を理解し、あるheadは単語の意味を理解します
最終的にそれぞれのheadを混ぜ合わせ、多角的に文章を理解する優れたheadを作成します
複数のheadを用いるため、ランダムフォレストと似た性質を持ちます
例えば、1つのheadがおかしな学習をしても最終的な出力に問題が出にくい特徴があります
詳しく
headは以下のように表現します
$\boldsymbol{head} = ( \boldsymbol{K}, \boldsymbol{Q}, \boldsymbol{V} ) \in \mathbb{R}^{N \times d}$
Multihead Attentionでは、headが複数あります
このとき、各headは以下のように分割された形状になります
\displaylines{
\boldsymbol{head}_i \in \mathbb{R}^{N \times d_k} \\
d_k = \frac{d}{h}
}
ここでは、headの個数をhとし、1, ... i ..., hという感じで番号を振っています
各headは、それぞれ独立して学習されるパラメータ$\boldsymbol{W}^K_i$, $\boldsymbol{W}^Q_i$, $\boldsymbol{W}^V_i$をもちます
学習済みのheadを結合し、1つのheadを作成します
つまり、以下のようにします
\displaylines{
\boldsymbol{head}_l = ( \boldsymbol{head}_1, \boldsymbol{head}_2, ... \boldsymbol{head}_h ) \in \mathbb{R}^{N \times d} \\
}
ただし、このままではそれぞれのheadがそれぞれ別の情報を担当している状態になっています
つまり、1箇所の情報に1つのheadのみ関与するということになるので、headが1つのときと比べて性能が上がるどころか、headが別れている分かえって性能が下がります
それぞれのheadがうまく協調してsource全体を処理するようにしたいです
そこで、以下のようにパラメータ$\boldsymbol{W}$をかけて、head同士を混ぜ合わせます
これによって、1箇所の情報に複数のheadが関与してくれるようになります
\displaylines{
\boldsymbol{head} = \boldsymbol{head}_l\boldsymbol{W} \in \mathbb{R}^{N \times d}\\
\boldsymbol{W} \in \mathbb{R}^{d \times d}
}
何に使うのか
Multihead Attentionは文法、語彙、論理性、ニュアンスなど、私たちが自然にしているような、多角的な文章理解をするために用いられます
Source target AttentionやSelf Attentionと競合しないため、それらがさらに高度なタスクを遂行できるようにするための追加機能という立ち位置です
Attentionはあなたを見ている
ここまで紹介してきたattentionという仕組みは、本来翻訳目的で開発されたものですが、予想をはるかに超え、文章の理解・生成、画像の理解・生成、分類など幅広い分野で非常に高い性能を誇っています
例えば、現在AIの象徴として君臨しているGPTのパーツにもMultihead Self Attentionが使われています
現代のAI技術における支配者といっても過言ではないかもしれません
