LoginSignup
2
2

RWKV (Receptance Weighted Key Value)

Last updated at Posted at 2023-09-28

RWKV についての調査記録。論文内での順に従って書いている訳ではないです。
https://arxiv.org/pdf/2305.13048.pdf

2023年5月に発表され、Transformerを凌ぐのではないかと話題のモデル。
著者部分を見ればわかるようにたくさんの人と組織が関わっている研究。Discordで共同研究しており、誰でもDiscordに参加できるようになっているらしい。
Transformerの概念を踏襲しつつも、全く別物なので理解するのが結構大変だった。

RWKVとは?

簡単に言えば、RNNとTransformerのいいとこ取りを試みたモデル。
学習時には大規模な並列化が可能で、推論時には逐次計算することでメモリの消費量を大幅に削減可能。

【RNN】
・Hidden stateベクトルとトークンを受け取り、Hidden state ベクトルを更新していく
・Last hidden stateを使用してテキスト分類や次トークン予測が可能
❌ 入力シーケンスを一度に1トークンずつ順番に処理するため、並列化できない

【Transformer】
・各トークンからQuery, Key, Valueベクトルを計算し、全トークンペア間のattentionを計算する
・全トークンを並列に処理できるので効率的にtrainingでき、大規模なモデルはRNNより高性能
❌ 処理するシーケンスの長さに対して二次関数的にスケールするため、コンテキストサイズが制限される。
❌ テキスト生成の際には、全トークンのattentionベクトルをメモリに保持する必要があるので、単一のhidden stateしか保存しないRNNよりも膨大なメモリが必要。

【RWKV】
・訓練では、トークン数に応じて線形にスケールする(一種の) Attentionを使用することで大規模な並列化が可能。
・推論では、単一のhidden stateを持つRNNのように動作する。すなわちトークンと状態を受け取り、次のトークンと新しい状態に対する確率分布を出力する関数として働く。並列化は不可だが、トークンが増えるほど大量の計算結果を保持しなければならないTransformerと違って省メモリでサクサク動く。
RNNの拡張性とTransformerの並列性を得ている。 実際、コンテキスト長8192で訓練されたRWKVモデルは、コンテキスト長1024モデルと同等の速度で動き、必要なRAMのサイズも同じ。
理論的には無限のコンテキスト長を扱える(訓練時よりもはるかに長いコンテキスト長に汎化するのは難しいだろうとは述べられていた)。

image.png
まあ図を並べたところでどういいとこ取りしたかはよくわからないですね...
とりあえず、RWKVの基本構造はTransformerではなくRNNだということを把握してください。

先行研究

Transformerベースの視覚タスクにおいて、self-attentionが以前考えられていたほど必須ではない可能性が先行研究(Tolstikhin et al., 2021; Liu et al.,2021; Yu et al., 2022)により示唆されている。
とはいえ、自然言語タスクにおいてself-attentionを完全に廃止しちゃう度胸もなかったので、本研究では、固定的なQKV式をKVに置き換え、時間減衰因子Wを新たに導入することで、attentionメカニズムを部分的に解体した。このアプローチにより、MLP-mixer(Tolstikhinet al., 2021)に似たトークンとチャンネル混合コンポーネントと、gMLP(Liu et al., 2021)に似たゲーティングユニットRを組み込むことができる。

また、本研究はAFT(An attention free transformer. Zhai et.al., 2021)というモデルに着想を得ている。
なのでTransformerのことは一旦忘れてAFTを頭に入れとくとよい。

【AFT】
・全体でfixedなw (Interaction matrix)を持つ。これは学習済みのペアワイズの位置バイアスを表す。   
 sequence length×sequence lengthの行列なので、語が何であるかは関係なく、sequence length内の位置関係のみで注目度を決めている。
・各語についてkを持つ。これはwで決められたfixedな値からの、語ごとのmodulationとして働く。
  語によって、幾つ手前のワードにどれくらいattentionするかを決めている。
image.png

コードを見ながらRWKVの構造を把握

学習済みモデルを使用した推論(文生成)過程 において、"My name is" に続く単語を予測する状況を例とする。 ※パラメータは固定とみなす。推論過程なので並列処理はできない。

まず、コード・モデル概念図・アーキテクチャ図を示す。それぞれの中で緑の枠で囲まれた部分がRWKVユニットに当たる。
image.png

以下は処理の流れ。赤く囲った部分が右の図にあたる部分。
image.png

次に、RWKVレイヤー(RWKV関数)部分についてのみ詳しく見る。
RWKVレイヤーにはtoken(単語)が1つずつ順に入力されるのがポイント。各tokenをモデルで扱う際は、tokenizerでencodeして数値に変換する。この例では、"My name is" は [3220, 1416, 310]とencodeされる。
続く単語を予測させようとしているので、"My" "name" "is"すなわち[3220, 1416, 310]を順にRWKVレイヤーに入力してprobsとstateを更新する。

"My"をRWKVレイヤーに入力すると、Input Embeddingの3220行目にあたるベクトル、すなわち"My"の埋め込み表現が取り出され、それを内部でTime MixingとChannel Mixingブロックの処理にかけてprobとstateを出力するという流れ。
Input Embeddingは、vocab size × dimension(論文中ではchannelと呼ばれる) の大きさで、trainの結果得られるもの。
image.png

さてTime MixingとChannel Mixingブロックでは何が起きているのだろうか。さらに細かい図にすると以下のようになる。
まずTime Mixingブロックに入り、wkvなる値を計算してσでゲーティング(LSTM等でいうForget Gate)しFeed Forward。次にChannel Mixingブロックに入りKを計算、そこからVを計算してσでゲーティングという流れ。
Time MixingのR,K,VとChannel MixingのR,K,Vはそれぞれ別物なことに注意。

ちなみに論文では、
・Time Mixing: Transformerではmulti-head attentionが担う部分
・Channel Mixing: Transformerではfeed-forward networkが担う部分
と述べられている。
image.png
コードと対応させるとこんな感じ。 ※ Wk, Wv, Wr は 1024 × 1024の行列
image.png

【wkvとは】
wkviは, kに従った重みを持つvの加重平均。単なる加重平均ではなく、現在のvにはボーナス(u)が追加され、過去のvは、遠ければ遠いほど減衰する重みが与えられている。減衰する重みが与えられているというのは、モデル全体で一つのdecay vector wを共有(これはAFTと似た概念だ)し、タイムステップが現在着目しているtokenより遠く過去になるほど線形に減衰する ということ。
変数はすべて1024次元(1024のchannelがある)で、すべてのchannelは互いに独立に計算される。
image.png

ここでTransformerのAttentionの式を見てみると、これもvのweighted sumであるが、
wkviは相互作用がスカラー間なので、二次コストになるTransformerと違い、線形のコストで
Attention(Q, K, V)を模倣した役割を果たすというのが強みである。
※RWKVにおけるl,vというネーミングはあくまでもTransformerに模倣して見せるためであって、同じ働きをしている訳ではないことに留意。
image.png

2つのMixingブロックを抜けたら、あとは確率の値にしてやるだけ。
image.png

Trainingについて

テキストに対する予測確率のCross entropy lossで全パラメータに対する勾配を計算する。パラメータ更新はAdam。
訓練では並列化が可能というのがRWKVのウリ。
線形なので、以下のように
「My → name」のTime Mixingブロックの計算を完了した時点で、「name → is」のTime Mixingブロックの計算を始められる。というのもあるし、
image.png

Githubリポジトリで、並列に訓練できる理由が以下のように述べられている。
"Because the time-decay of each channel is data-independent and trainable."
これはすなわち、
numやdenは
・j(文章内の位置)
・w(モデル共通のパラメータ)
・k(xと一つ前のxからリニアに算出される) 
で決まる。
入力トークンが最初から最後までわかっているtrain時であれば、
順番に計算していく必要がない。
また、相互作用は、あるタイムステップ内では乗法的だが、他のタイムステップでは加算的なので、
各タイムステップ同士が独立して学習できる という意味でもある。

その他の最適化

カスタムカーネル:
標準的な深層学習フレームワークを使用した場合のタスクの連続的な性質によるWKV計算の非効率性に対処するため、トレーニングアクセラレータで単一の計算カーネルを起動するように、カスタムCUDAカーネルを実装した。それ以外の部分はすべて行列の乗算やポイントワイズ演算で、効率的に並列化することが可能である。

小さな初期埋め込み:
Transformer(Vaswani et al, 2017)のトレーニングの初期段階では、埋め込み行列がゆっくりと変化し、モデルが初期のノイズの多い埋め込み状態から逸脱しづらいことが課題である。この問題を軽減するために、我々は、埋め込み行列を小さな値で初期化し、その後、追加のLayerNorm操作を適用するアプローチを提案する。この手法を導入することで、学習プロセスを高速化・安定化。
モデルが初期の小さな埋め込みから素早く移行することで、コンバージェンスが向上することを示した。

カスタム初期化:
先行研究(He et al., 2016; Jumper et al., 2021)の原則に基づき、対称性を崩しながらパラメータをできるだけ同一性マッピングに近い値に初期化して、きれいな情報経路を確保する。ほとんどの重みはゼロに初期化される。線形層にはバイアスは使用されない。具体的な計算式は付録Dに記載されている。初期化の選択は、収束の速度と品質に大きな影響を与えることを示した。
image.png

評価

以下の項目について評価した。

  • RQ1: RWKVは、同じ数のパラメータとトレーニングトークンを持つ二次関数的なTransformerに対して競争力があるか?
  • RQ2: パラメーターの数を増やした場合、RWKVは二次関数的なTransformerに対して競争力を維持できるか?
  • RQ3: RWKVのパラメータを増やすと、オープンソースの二次関数的なTransformerでは処理しきれない文脈の長さに対してRWKVモデルを学習させた場合、言語モデリングの損失は改善されるか?

【RQ1・RQ2】
6つのベンチマークにおいて、RWKVが以下の主要なTransformerより優れていた。
Pythia(2023), OPT(2022), BLOOM(2022)
image.png
【RQ3】
Context lengthを長くするとPileでのテストロスが減少。
→ RWKVは長い文脈情報を効果的に利用できていると言える。
image.png

推論実験

【テキスト生成の速度とメモリ要件を評価】
- float32の精度を使用。
- パラメータ数は、埋め込み層と非埋め込み層の両方を含む、全モデルパラメータ。
- 量子化設定の影響は今後の研究に委ねる
image.png
Transformerに比べて、推論時間も線形の増え方で済んでいることが確認できる。

【ゼロショット性能を比較】
RWKV4-Raven-14B vs ChatGPT(2023.2にアクセス) vs GPT-4
- プロンプト:ChatGPTから適切な応答を受け取れるように手動で選択されたものを全てに対して使用。
image.png
GPT向けのプロンプトをそのまま使うと、RWKVは著しく劣るのだが...
・プロンプトが自然言語である
・RWKVはRNNであるため、命令の内部を振り返ることができない
という点を考慮し、質問→必要な情報 の順にプロンプトの形式を変更した結果、一部のデータセットで品質が大幅に向上した。
RNNベースのアーキテクチャーは過去にさかのぼって以前の情報の重みを再調整することができないため、RWKVモデルはコンテキスト内のコンポーネントの位置により敏感なのである。
image.png

RWKVのLimitations

標準的なTransformersの2次的attentionが維持する完全な情報と比較して、多くの時間ステップにわたって単一のベクトル表現を通すのやはり情報が漏れていく。
リカレントアーキテクチャは以前のトークンを「振り返る」能力を本質的に制限している。学習された時間減衰は情報の損失を防ぐのに役立つが、完全な自己注意に比べるとメカニズム的に限界がある。
本研究のもう一つの限界は、標準的なTransformerモデルと比較してプロンプトエンジニアリングの重要性が増していることである。RWKVで使用されている線形注意メカニズムは、プロンプトからモデルの継続に引き継がれる情報を制限する。その結果、注意深く設計されたプロンプトは、モデルがタスクをうまくこなすためにさらに重要になる可能性がある。

Ref

開発者による解説ブログ
https://johanwind.github.io/2023/03/23/rwkv_overview.html
https://johanwind.github.io/2023/03/23/rwkv_details.html

分かりやすい論文読み動画
https://www.youtube.com/watch?v=x8pW19wKfXQ

https://huggingface.co/blog/rwkv
https://izmyon.hatenablog.com/entry/2023/06/06/093438
https://note.com/npaka/n/n8f3c2c491901
https://note.com/hamachi_jp/n/n2c971e07db63#2446bfde-c25b-4b36-a7f6-64afc0a9f7ed
https://zenn.dev/jow/articles/f66d6403b9a509
https://zenn.dev/hikettei/articles/5d6c1318998411
https://gigazine.net/news/20230709-rwkv-language-model/

2
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
2