こんにちは,Nakaiと申します.
今回はGoogle Researchから2021/04/18にarXivに投稿されたチューニング手法であるPrompt Tuningについてゼミで紹介するので,ついでにQiitaにも投稿させていただこうと思います.
原論文 : The Power of Scale for Parameter-Efficient Prompt Tuning
codeの公開は今の所なさそうです(2021/05/01現在)
数式及び図は基本的に論文から引用しています.
また,私は普段は画像認識の領域に関わっており自然言語処理にはあまり触れたことがないので,原論文やその他の論文,記事を参考にしながら推測しているところがいくつかあります.間違いやニュアンスの違いなどあるかもしれませんが,その時はコメントなどしていただければと思います.よろしくおねがいします.
#論文の概要
- GPT-3やT5のような超大型汎用モデルを凍結し,prompt部分のみを学習してModel Tuningに迫る精度を出すPrompt Tuningを提案
- 従来法であるPrompt Designから細かな調整を排除し,精度やロバスト性が向上したことを確認
- 特にドメインシフト問題に関してはModek Tuningを凌ぐ精度
#はじめに
近年,自然言語処理はBERTやGPT-3のように,何千億何兆ものパラメータを持ち様々なタスクに対応する巨大な汎用型モデルの台頭が目立ってきました.このようなモデルの巨大化に伴い,新規タスクを行う際にいかに簡単に低コストでTuningするかといったことが度々取り上げられてきたようです.他分野でも多く取り上げられているModel Tuningでは,再学習が非常に困難ですし,そもそもモデルを手持ちのマシンに載せないといけません.化け物みたいなマシンを持っているならまだしも,一個人には到底無理です.
そんな中,GPT-3あたりから徐々にPrompt Programing[解説記事],あるいはPrompt Designという考え方が出てきました(Prompt Designはおそらく本論文独自の言い回しかと思います).Prompt Designを説明するために,まずは簡単にPromptについて説明します.
##Promptとは
Promptとは,入力の前にモデルにどんなタスクを課すかを指示する目印のようなものです.
人間にわかりやすく例えると,以下のような感じです.
(ex)翻訳タスクのデータとラベル
データ : translate Japanese to English : こんにちは.
ラベル : Hello.
この例で言う,"translate Japanese to English"がpromptです.
日本語を英語に訳すだけのモデルを作るのであればpromptは必要ありません.しかし,英語を日本語に訳したり,中国語やドイツ語など多言語で訳したり,自然言語の別タスクである質問応答や穴埋めなど様々なことをモデルにさせたければ,promptを用いてモデルに何をしてほしいかを指示することが効果的です.
GPT-3は学習の段階からこのようなデータで学習し,一つのモデルで様々なタスクに対応しようとしています.
##Prompt Designとは
Prompt Designとは簡単に言うと,promptを適切に設計して,モデルを再学習することなく未知のタスクを解こうという試みです.
Promptもマシンにとってはただの数字の羅列です.もともとprompt込みで学習されたモデルでは,この数字の羅列から求められている出力を推測し,適切な出力を出そうとします.モデルが大きくなれば大きくなるほど様々なタスクに対応できる汎用的な特徴抽出器が途中途中に挟まっているので,このpromptをうまいこと設計してあげれば未知のタスクにも対応できるのではないか,ということです.
このアプローチは2020年辺りから活発に研究されていたようですね[1][2][3]
これらの手法は主に,どのようにpromptを設定すればいいか,どのようにして適切なpromptを探索すればいいのかといった内容になっています.
##Prompt Designの欠点
しかし,これらの手法について今回の論文は以下の問題を指摘しています.
- エラーを起こしやすく,エンジニアの技量や探索法による細かい調整が必要
- Promptの有効性がモデルの最大入力長に依存する
- Model Tuningと比べると,精度が遥かに劣る
特に最後は深刻です.以下のグラフが示すようにSuperGLUEという自然言語の様々なタスク(質問応答,自然言語推論,因果推論など)のスコアを総合的に判断するスコアでは,Model Tuningと比べて従来法のGPT-3 Prompt Designは10ポイント以上,大きいところだと25ポイント近く差をつけられています.
この点提案手法であるPrompt Tuningは,特に巨大なモデルでModel Tuningに迫る精度を達成しています.
#Prompt Tuning
ではどのようにpromptを設計すれば,より良い精度を得られるのでしょうか.それは深層学習の歴史が教えてくれます.
従来のPrompt Designは主に人の手や探索アルゴリズムなどによってpromptを設計してきました.しかし,探索や検索を行うアルゴリズムはどれも微分できないものばかりです.さらに,用いられるpromptはもともと持っている埋め込み表現しか用いていません.つまり,人間にとって意味のある語彙の埋め込み表現しか用いていないのです.
そして今回相手にしているのは何千億何兆ものパラメータを持つ巨大なモデルです.
深層学習はあらゆるものを微分可能な関数に置き換え,最適化によって成果を上げてきました.そこで,この論文では微分ができる関数を用いてpromptの独自のパラメータを用意し,最適化してしまおう,というアプローチが取られています.
以下にModel TuningとPrompt Tuningの概要図を示します.
Model TuningではTask一つ一つに対してモデルを用意し,チューニングする必要があります.小規模なモデルであればこれで事足りるのですが,巨大なモデルの場合は簡単ではありません.例えば論文で用いられている自然言語モデルであるT5には最大で110億ものパラメータが必要です.
そこでPrompt Tuningの出番です.Prompt Tuningでは複数のタスクに対してモデルは一つで十分です.モデルの代わりに.タスクごとにprompt用のパラメータを用意し,最適化します.具体的なパラメータ数は,例えばpromptのトークン数(単語数のようなもの)を100,単語の埋め込み表現の次元数を4,096次元とすると,100×4,096=409,600となり,Model Tuningと比べると約26,855分の1と,大幅なパラメータ削減になります.もはやそのへんの軽量モデルよりも少ないですね.
ここまでまとめてきたModel Tuning, Prompt Design, Prompt Tuningをまとめて式にすると,以下のようになります,ここで,$X$,$Y$はそれぞれ入力及び出力であり,$\theta$はモデルのパラメータ,$P$はPrompt,$\theta_P$はPromptのパラメータです.$P$と$p$の違いにご注意ください.
- Model Tuning
$$p_{\theta} (Y|X)$$ - Prompt Design
$$p_{\theta} (Y|[P;X])$$ - Prompt Tuning
$$p_{\theta, \theta_P} (Y|[P;X])$$
Model Tuningはシンプルに,与えられた入力$X$に対して適当な$Y$を返すようにパラメータ$\theta$を再学習します.Prompt Designでは入力$X$にPrompt$P$を追加し,パラメータ$\theta$を固定したモデルに入力し,出力を得ます.Prompt Tuningでは入力$X$にPrompt$P$を追加し,パラメータ$\theta$を固定したモデルに入力することは変わらないのですが,Prompt$P$は学習可能なパラメータ$\theta_P$に依存します.
訓練段階では
- n個のトークン${x_1, x_2, \cdots , x_n}$から埋め込み表現行列$X_e\in R^{n\times e}$を取得
- さらに$\theta_p$から得られるPrompt $P_e\in R^{p\times e}$を$X_e$に連結し,$[P_e ;X_e]\in R^{(n+p)\times e}$を生成
- 通常通りEncoderに入力
-
$\theta_p$の勾配のみ計算し,$Y$の尤度を最大化するように$\theta_P$を更新
ここで,$p$,$e$はそれぞれPromptのトークン数,埋め込み表現行列の次元です.
#結果
Promptの初期化方法,モデルのPretrain方法,Ablation Studyについては少し深い内容になるので,後述の補足に回します.
##従来法との比較
従来法との比較として,Model Tuning, GPT-3でのPrompt Designとの比較を以下の図に示します.
途中でも同じ図を示しましたが,Model Tuningには届かないまでもPrompt Designには大きな差をつけています.またモデルが大きくなり,T5モデルの最大サイズである110億パラメータまで行くと,ほぼModel Tuningに並びました.
##ドメインシフト問題
Model Tuningと比較してどれくらい過適合に強いかも示されています.
ドメインシフト問題とは,簡単に言うとデータに適合しすぎていないかという問題です.同じタスク,同じ分類クラスの2つのデータセットがあったとして,片方で学習したモデルがもう片方のデータセットにも適応できているかという問題です.
以下の表は同じタスク(入力された2つの文章が同じかどうか)をもつ2つのデータセット(QQP,MRPC)に対して,zero-shot(まったく再学習しない)でどこまでの精度が出せるかというものです.
ドメインシフト問題においてはむしろPrompt TuningがModel Tuningより良い結果であることがわかります.つまり,過適合しにくいということですね.
※ただし,ドメインシフト問題には様々な問題設定があるらしく,今回の論文ではあまり深くまではやっていませんでした.
##Prompt Ensembling
複数の出力から多数決などによって最終的な出力を決定するEnsemble学習についても触れられています.今回のPrompt Tuningでは個別にモデルを準備する必要がありません.Promptを準備するだけで十分です.ということで,Ensemble学習もやっていました.
上の表は各タスクにおける5つの学習したPromptのAverage score , Best score, Ensemble scoreをまとめたものです.
T5のような巨大なモデルでも簡単にEnsemble学習ができ,精度を更に上げることができると主張しています.
#まとめ
いかがだったでしょうか.今回は自然言語処理の分野から巨大なネットワークで簡単にかつある程度高精度な結果を得られるPrompt Tuningについてまとめてみました.
自然言語処理の分野以外ではまだ巨大モデルというのはあまり見かけませんが,今後出てきたときに再注目されるTuning手法かもしれません.また,Promptに縛られなくても,出力間際ではなく入力をいじってみようというアプローチはもしかしたらどこかで使えるかもしれません.
また細かい訓練方法などは論文を見る限り,書いていないと思います.私はこれを直接用いるというより,自分の分野に応用できないかという視点で読んでいたの詳しくは目を通していないのですが,もっと詳しく知りたい!という方は,現状日本語では記事が殆どない(私が確認したのは一件のみ,内容もそこまで深くない)ので,論文やgithubの更新を待ったほうが良いかもしれません.
この記事が皆さんのなにかの参考になれば幸いです.
また最近研究,勉強用のTwitterをはじめました.ほんとに最近でまだ全然ですが,よろしければフォローお願いします.
#補足
##Promptの初期化方法
Promptの初期化方法ですが,論文では以下の3つが提案されています
- ランダムに初期化
- 本来のトークン埋め込み表現から,もっともらしい表現を取得
- ほしい出力の語彙から埋め込み表現を取得(クラス分類問題に限る)
1つ目は説明不要かと思います.
2つ目はもともとモデルがもっている語彙の埋め込み表現から,例えば翻訳であればtranslateなどの埋め込み表現を初期値として設定するというものです.
3つ目は(推測になるのですが)入力文章がどのような文章なのかに分類するとき,その目標出力値を初期値として設定してしまおうというものです.
##モデルのpretrain方法
T5のもともとの訓練方法には,今回のPrompt Tuningには不都合なことがあります.
T5は自己教師あり学習で学習し,以下のようなデータ,入力,出力をとっています.
data : ”Thank you for inviting me to your party last week”
input : “Thank you <X> me to your party <Y> week”
output : “<X> for inviting <Y> last <Z>”
input,outputともにまともな英文ではありませんが,モデルはしっかりと文章の特徴は掴んでいるので,出力方法を自在に変えられるModel Tuningであればなんの問題もないわけです.
しかし今回は入力しかいじりません.これはなかなか難しい,ということで以下の3つの手法で比較実験し,結果をAblation studyに示しています.
- Promptのみを変更
- 目標テキストの前全てに<X>をつけ,少しでも本来のT5の出力に近づける
- prefix language modeling(prefix LM)にて再学習
1つ目はそのまま行く,ということです.ただpromptを変えるだけでどこまでいけるか試します.
2つ目は,元のT5の出力に<X>やら<Y>やらがついているので,それがつくようにpromptを学習していきます.学習方法に寄せたといった感じです.
3つ目はprefix LMを用います.prefix MLとは入力および出力済みのテキストに対してattentionをかけ,学習するという学習手法です.要するに完全に再学習です.この手法はT5の論文でもあつかっており,くわしくはそちらをご覧ください.
##Ablation study
Ablation studyでは上のpromptの初期化方法,モデルのpretrain方法の他に,promptの長さ,pretrainの学習ステップ数についてまとめてあります.
左上:Promptの初期化方法について
右上:Promptの長さ
左下:モデルのpretrain方法
右下:prefix LMでpretrainするときの学習step数
Promptの初期化方法については,ランダム以外の2つは似たりよったり,ランダムでもモデルが最小,または最大のときにはあまり変わらないということ.これはなかなか謎ですね.
Promptの長さも同様に短いときはあまりいい精度が得られていません.しかし,モデルが最大になると結構いい感じにもなっています.また,長さが100を超えると大差ないということが読み取れます.
モデルのpretrain方法は,やはりprefix LMが最も良かったようです.だたここでもモデルが最小,または最大のときにはあまり変わらないという謎の現象が起こっています.
最後のpretrainの学習step数については多いほうが良さそうですね.だたここでも上のように,モデルが最小,最大のときに殆ど変わらないと言った現象が起こっています.
これはなかなか謎ですね,まだまだ見るところがありそうです.