お久しぶりです。あるいは初めまして。
Yosematです。
今回はGPT-3に代表される大規模言語モデルがなぜプロンプトによってあそこまで高精度な自然言語を繰り出すことができるのかについて分析したMicrosoftの最新の論文を解説します。
論文リンク:Why Can GPT Learn In-Context? Language Models Secretly Perform Gradient Descent as Meta-Optimizers。
忙しい人向け
- Attentionモデルにプロンプトを与えることは近似的には線形モデルをバックプロパゲーションすることと同じだった
なぜこの論文が重要か(Yosemat解釈)
この論文が重要になってくるのは今後言語モデルはファインチューニングではなくプロンプトによって学習させることが主流になっていくからです。OpenAIなどの言語モデルもモデル自体は公開されずプロンプト用のAPIが公開されるのみですし、モデルが公開されたところで学習を行うことは容易ではありません。はっきりいって大企業以外にとって言語の先端技術を開拓する方法はプロンプトを除いて閉ざされています(言い過ぎ?)。
しかしどうやらプロンプトでもファインチューニング(実際に学習させる行為)と同じようなことが起こっているんじゃないか、というのが今回の論文なのです。
プロンプトはなぜワークするのか
プロンプトとは
プロンプトとはなんだったでしょうか。それは言語モデルの出力を意図したものに制御するための入力のことです。
現在公開されている最強言語AIであるChatGPTに「言語モデルにおけるプロンプトとはなんですか?」というプロンプトを入力してみましょう。モデルはこれを受けて出力を生成します。
ここでは簡単のために次のような典型的なプロンプトを考えます。
// プロンプトテキスト(例示)
りんご⇒赤
メロン⇒黄緑
なす⇒青
そのあとで次のようなクエリを与えることでモデルから出力を引き出します。
// クエリテキスト
バナナ⇒
言語モデルは例示を踏まえて次に来るべきテキストを作成します。GPT-3のような優れた言語モデルなら「黄色」やそれに準ずる回答をしてくれます。
プロンプトの中身はAttention
プロンプトを踏まえるというのは結局のところSource-Target AttentionのSourceとしてプロンプトテキストとクエリテキストをつなげたものを、Targetとしてクエリテキストを使うことを指します。Source-Target Attentionについてはこちらの大変わかりやすい記事を参照。
プロンプト$X'$とクエリ$X$が入力された時はシンプルに$X'$と$X$をつなげたものがSourceとなるのでプロンプトを使った言語モデル$F_{ICL}$(In-Context Learning)は
F_{ICL}(q) = W_V [X'; X] softmax \left(\frac{(W_K [X'; X])^T q}{\sqrt{d}}\right)
q = W_Q X
とかけます。$[X'; X]$は行列の結合です。
AttentionモデルのPromptingとLinearLayerのFineTuningは(ほぼ)一緒だった!?
ICLモデルの簡単化
分析を簡略化するためにAttentionのsoftmaxとScaling Factor$\sqrt{d}$を取り除きます。
するとICLモデルは
\begin{align}
F_{ICL}(q) &= Attn(V, K, q) \\
&\approx W_V[X'; X](W_K[X'; X])^Tq \\
&= W_VX(W_KX)^Tq + W_VX'(W_KX')^Tq \\
&=\widetilde{F}_{ICL}(q)
\end{align}
とかけます。
Zero-Shot Learning(ZSL)のケースを考えましょう。$X'$の項が消えてなくなりますので$W_{ZSL} = W_VX(W_KX)^T$と定義しても問題ないはずです。
\Delta W_{ICL} = W_VX'(W_KX')^T
とするとICLモデルはZSLモデルの重み行列を変化させたものととらえることができて
\begin{align}
\widetilde{F}_{ICL}(q) &= W_{ZSL}q + W_VX'(W_KX')^Tq \\
&= W_{ZSL}q + \Delta W_{ICL}q \\
&= (W_{ZSL} + \Delta W_{ICL})q \\
\end{align} \tag{1}
とかけます。
また同じことを別の書き方をすれば
\begin{align}
\widetilde{F}_{ICL}(q) &= W_{ZSL}q + W_VX'(W_KX')^Tq \\
&= W_{ZSL}q + LinearAttn(W_VX', W_KX', q)
\end{align} \tag{2}
となります。
線形モデルの学習を重み変化の累積で表記
頭をクリアにして、次は線形レイヤー(kerasでいうところのDense, torchでいうところのLinear)の話をしましょう。
線形レイヤーは当然
F(x) = Wx
とかかれますよね。
今何度かのバックプロパゲーションによって重み行列$W_0$が$W_0 + \Delta W$へと変化したとしましょう。
線形レイヤーは
F(x) = (W_0 + \Delta W)x \tag{3}
と書かれます。
ところでバックプロパゲーションによる重み行列の更新は
合成関数の連鎖率を使えば
\frac
{\partial J}
{\partial W} = \frac{\partial J}{\partial F} \otimes \frac{\partial F}{\partial W}
と書けます。外積回りの線形代数につまずく人は短期的にギブアップするのがおすすめです。$J$は損失関数です。
$e = \frac{\partial J}{\partial F}$を後のレイヤーから伝播されてきた勾配だと解釈することで勾配が逆伝播されていく意味を理解できますね。
線形レイヤーにおいては
\frac{\partial F}{\partial W} = x
ですから重み行列の変化は
\Delta W = \sum_i e_i \otimes {x'_i}^T
となります。これまでに後段のレイヤーから降りてきた勾配とその時の入力の累積和になってるところが注目です。
線形レイヤーは結局
\begin{align}
F(x) &= (W_0 + \Delta W)x \\
&= W_0x + \sum_i(e_i \otimes {x'_i}^T)x \\
&= W_0x + \sum_ie_i({x'_i}^Tx) \\
&= W_0x + EX'^Tx \\
&= W_0x + LinearAttn(E, X', x)
\end{align} \tag{4}
となります。LinearAttentionのValueとして渡される値は過去にわたって受け取った勾配
プロンプトとのアテンションはひそかに重みを更新していた
(1)と(3)の比較から
$\Delta W_{ICL} = W_VX'(W_KX')^T$は線形レイヤーにおけるバックプロパゲーションにおける重みの累積変化$W_0$と等価であることがわかります。
(2)と(4)をみてみると$LinearAttn(E, X', x)$と$LinearAttn(W_VX', W_KX', q)$の対応関係から
$W_VX'$は線形モデルへの勾配$E$に対応しています。論文中ではこの$W_VX'$をメタ勾配と呼んでいます。
代数で遊んでいてよくわからなくなってきたのでもともとのお話に戻ってみましょう。
$X'$はプロンプトテキストでした。そして今回の分析でプロンプトテキストから計算される$W_VX'$は線形レイヤーの更新式における重みと対応付けられることがわかりました。
この分析からAttentionモデルにプロンプトを与えることは線形レイヤーに学習データを与えることと同じということが示唆されます。
もちろんこれまでいくつかの近似を用いてきましたので厳密には異なりますが、ざっくり同じなのです。
大きな解釈をすればわざわざOpenAIのように大規模モデルをTrainingしなくても彼らの公開しているAPIに学習データに相当する量のデータをプロンプトとして与えればよいということがわかります。さらに他のドメインのFew-Shot Learningの在り方も問われてくることになりそうです。
まとめ
今回はAttentionモデルのプロンプトがワークするのは実は暗黙のうちにファインチューニングと等価な計算をしているからであるというMicrosoftの研究について解説しました。
論文ではさらに踏み込んでいて
- これらが近似を乗り越えてどの程度妥当なのか検証
- バックプロパゲーションでMomentumが有効なのだからAttentionにもMomentum入れてみたら精度改善
等に関するディスカッションをしています。
とても有意義なので読んでみてください。
今後も深層学習に関する技術などを配信していきたいと思います。とても励みになるので「いいね」つけていただけますと幸いです。