2023年に公開された論文 RWKV: Reinventing RNNs for the Transformer Era を読んでみた。もしかするとAIに詳しくない人も耳にするかもしれない、現在多用されているAIモデル "Transformer" と、昔ながらのモデル "RNN" を組み合わせて、より精度良くしたものと言える。
この論文は何を言ってるんだろう?何がすごいんだろう?ということが伝わるといいなと、書いてみることにする。
RWKVとは?
RWKV (Receptance Weighted Key Value) は、新たなシステムを導入した機械学習モデルのこと。これを理解するにはまず、TransformerとRNNを知っておく必要がある。
(図の左側がRWKVを導入している部分、図の右側が入力から出力までのモデル)
Transformerとは
自然言語処理に革命をもたらしたとされるモデルである。最近よく聞くChatGPTや生成AIも今やこれが使われている。
自然言語処理の精度が飛躍的に向上したと言われるこのTransformerであるが、配列の長さに応じて必要メモリ量や計算量が2次関数的に増加してしまう。
つまり、Transformerは『高精度だけど複雑』なモデルといったところだろう。
RNNとは
RNN (Recurrent Neural Network) はある入力して出力するだけでなく、その出力をまた学習に戻して、といったように再帰的なシステムを導入したモデルである。以前からあるモデルであり、必要メモリ量や計算量は線形性がある。しかし、並列化等に限界があるためTransformer程の精度を出すことは難しい。
つまり、RNNは『単純だけど低精度』なモデルといったところだろう。
整理すると、
Transformer : 『高精度だけど複雑』
RNN : 『単純だけど低精度』
といったようになっています。これらのいいとこ取りをしたのが、このRWKVである。
何がいいの?
Transformerの『高精度』という部分と、RNNの『低計算量』という2つのメリットを掛け合わせたものである。この2つの観点からいい点を説明することとする。
精度
ChatGPTに使われているモデルをRWKVに変更したところ、性能が44.2%から74.8%に向上した。
その他にも、以下のようにいろいろなタスクについて学習が行われた。
オレンジ色で示すのが、RWKVの正答率である。
計算量
Transformer等と異なり、トークン(単語)数と時間が比例している。Transformerなどではトークン数が増加すると2次関数的に計算時間が増加するため、自然言語処理で扱う文章が長くなると計算時間が大幅に増加してしまう。しかし、RWKVはこのトークン数と時間が線形のため、扱う文章が長くなったとしても計算時間の増加は比例関係である。そのため、長い文章を扱うときにその計算量の速さがメリットとして現れる。
結論
RWKVを使うことにより、Transformerの『高精度』と、RNNの『低計算量』を兼ね備えた新たなモデルとなった。より長い文章を扱えることになったことにより、応用することで今後の自然言語処理の精度向上が見込まれる。