はじめに
みなさん初めまして!
松尾研究室修士1年の髙波亮介です。
今日の記事では、一つの巨大なモデルでイメージキャプショニングからシミュレーション上の制御タスクや実世界ロボットタスクまで解くことを可能にしたGatoを紹介します。
目次
Gatoとは?
Gatoとは、様々なモダリティを持つ情報を処理し、多様なタスクを大規模に学習させたモデルであり、今年DeepMindが発表した基盤モデルの一つになります。
図1:Gatoの全体像(出典:Gatoの論文のFigure 1)
Gatoが学習してるタスクは多岐にわたり、強化学習のタスクとして有名なゲームAtariやイメージキャプショニングタスク、さらには実ロボットを用いたブロックを積み上げるタスク
など様々な領域のタスクを解くことができます。
Gatoは何がすごい?
Gatoのすごい点は一言で言うと
同じニューラルネット、同じ重みで、様々な情報(モダリティ)を処理して、様々なタスクを解くことができる点
にあると思います。それぞれの領域の中でマルチタスクに対応する研究例は、自然言語処理や強化学習などで頻繁に目にすると思いますが、
モダリティや領域超えて大規模に学習する試みはこのGatoが初めてでしょう。
Gatoの概要
それでは実際にGatoの中身を見ていきましょう。
前処理
Gatoで大事なのが、マルチモーダルな情報を統一的に扱うためのトークン化です。各々の種類のデータに対してこのトークン化という前処理を施すことでGatoのモデルに入力する形式をタスクによらず統一することができます。
それぞれのモダリティをどのようにトークン化するか見ていきましょう。
テキスト
まず、テキストですが、SentencePieceという手法で文章を32,000種類のサブワードに分割し、サブワードを各サブワードごとに割り振られた整数IDに変換してトークン化を行います。
つまり、テキストは分割された順番に並んだ整数列にトークン化されます。整数IDは分割サブワード数が32,000なので0以上32,000未満の整数値を取ります。
トークン化がされた後は、モデルの入力次元に合わせるためにトークンから埋め込み表現へと変換されます。トークンから埋め込み表現へは学習済みのルックアップテーブルを用いて変換されます。
最終的に、テキストはサブワードの分割数だけの長さを持つベクトル列に変換されます。
画像
次に画像のトークン化ですが、画像はまず16×16のパッチに分割され、ラスタスキャンの順番に整列されます。そして、それぞれのピクセルを[-1, 1]の範囲で正規化し、最後にパッチサイズの平方根つまり4で割ったものをトークンとしています。
画像もトークン化された後に埋め込み表現へと変換されており、それぞれのパッチはResNetブロックを通してベクトル化されます。最終的には、テキストと同様ベクトル列に変換されています。
図2:画像と離散行動のトークン化(出典:Gatoの論文のFigure 15)
離散値、連続値
最後にAtariのようなゲームにおける行動選択やエージェントの状態、ロボットにおける関節角などの離散値や連続値のトークン化について見ていきます。
まず、ゲームにおける行動選択の内の離散的なもの(例えば、ブロック崩しで言う右に行くボタンを押す)については、row major orderでデータを一次元化した後、そのまま行動に対応した整数IDをトークンとします。IDは0以上1,024未満としています。
次にエージェントの速度やロボットの関節角といった連続的な値は、同様にrow major orderでデータを一次元化した後、[-1, 1]の範囲にない値はmu-lawエンコードという手法を用いてエンコードされます。[-1, 1]の範囲に収められた値は一様の幅で1,024に分割され、整数IDが割り当てられます。最後に離散化されたIDを32,000以上33,024未満の範囲にシフトしています。
こちらもテキストと同様にトークン列を学習済みルックアップテーブルを用いて埋め込み表現へと変換されます。
図3:エージェントの状態と連続行動のトークン化(出典:Gatoの論文のFigure 14)
注意点
強化学習タスクにおいては、状態 (観測) と行動の系列が入力として与えられるのですが、その際の入力系列は、各タイムステップの表現を状態トークン、セパレーター、行動トークンという順番に並べたものをとし、時系列順にこの表現を並べたものになっています。
モデル
続いて、Gatoのモデルについて見ていきます。
Gatoでは主に二つのモデルを学習しており、トークンから埋め込み表現へと変換するモデルと過去のトークン(埋め込み表現)系列から次のトークンの分布を予測する系列モデルがあります。
トークンから埋め込み表現への変換を担うモデルでGatoで学習されているのは画像のパッチを埋め込み表現にするResNetブロックのみであると考えられます (図4) 。テキストや離散値や連続値はおそらく事前に学習されたルックアップテーブルを用いていると考えられます。
図4:画像トークンの埋め込みに用いられるResNetブロック(出典:Gatoの論文のFigure 16)
Gatoのメインのモデルは系列モデルの方にあります。Gatoの系列モデルはタスクの実行主体であり、自然言語処理や画像処理、深層強化学習で近年よく用いられているTransformerのデコーダー部分が採用されています。
GatoのTransformerは、約11.8億パラメータのモデルとなっており、規模としては、rinna社の日本語GPTモデル (約13億パラメータ) に近いです。近年のGPT3といった大規模自然言語処理モデル (約1,750パラメータ) と比べると
小さく感じるかもしれませんが、一般研究者からしたら巨大なモデルであることには変わりません。
表1:Gatoで用いられるTransformerの詳細(出典:Gatoの論文のTable 5)
学習
図5:Gatoの学習の流れ(出典:Gatoの論文のFigure 2)
Gatoでは系列モデルであるTransformerの学習がメインとなっています。Transformerは以下のような損失関数 (式1) で学習されており、過去の系列で条件づけたときに、次の行動、次のテキストとして正しいものを出力するように教師あり学習されます。
式1:損失関数(出典:Gatoの論文の式 (2))
損失関数のmの部分はマスクを表しており、予測したい対象のみ1、それ以外の予測しない物については0を返すようなマスクとなっています。Gatoにおける予測対象とは、制御タスクにおける行動であったり、キャプションのようなターゲットテキストにあたります。
また、それぞれの領域 (自然言語処理や制御) の中で異なるタスクをいくつも取り扱うGatoは、その領域の中のどのタスクに取り組んでいるかということも学習する必要があります。
そこでGatoの学習では、それぞれのバッチ内の系列データ群の内25%に対して、同一教師エージェント同一タスクの系列データからプロンプトとなる系列を一部切り出し、条件付けとして冒頭に挿入しています。
挿入するプロンプトの内、半分はエピソードの終端から取った系列であり、もう半分は、エピソードからランダムに切り出した系列としています。前者はゴール条件付け学習のような効果が期待され、後者はゴールではありませんが、どういったタスクを解いてほしいかを提示する役割があるのではないかと思われます。
データセット
Gatoで用いられるデータセットは主に3つに分けられます。シミュレーション上での制御タスクのデータセット、Vision & Languageタスクのデータセット、ロボットの制御タスクのデータセットの3つです。
これらデータセットの詳細な割合と概要は以下の表2のように論文中で示されています。
表2:データセット(出典:Gatoの論文のTable 1)
シミュレーション上での制御タスクのデータは主にSoTAやSoTAに近い強化学習エージェントによって生成されたものを利用しており、完全に学習が完了したエージェントではなく、学習中のエージェントのデータを利用しています。
Vision & Languageタスクのデータは、純粋にテキストのみのデータもあれば、画像が与えられた時にそれに合致する説明文 (キャプション) を生成するイメージキャプショニング用のデータやVisual Question Answering (VQA) タスクのデータも用いられています。
ロボットの制御タスクのデータは、図9のようなブロック積み上げタスクのセットアップで収集されています。
シミュレーション上でエキスパートエージェントを学習させた後、シミュレーション上でまずエキスパートデータを収集し、そしてsim2realエージェントを使って実機ロボットでもデータを収集しています。
図6:実機ロボットのセットアップ(出典:Gatoの論文のFigure 4)
実験結果
それではここからは上記のような大規模データセットで学習されたGatoの性能評価についてみていきましょう。
イメージキャプショニングとQuestion & Answeringの結果
図7はGatoのイメージキャプショニングの結果です。一部誤ったキャプションを生成していますが、おおむね画像に合ったキャプションが生成されています。
図7:イメージキャプショニングの結果(出典:Gatoの論文のFigure 6)
図8はGatoによるQAタスクの結果です。上段も下段も「フランスの首都はどこか」という質問がGatoに投げかけられていますが、上段は間違った結果を下段では正しい結果を返しています。
このように同じ質問でも異なる結果を返してしまうようですが、モデルをさらにスケールさせることで解決されるであろうと筆者らは考察しています。
図8:GatoによるQAチャットbotの例(出典:Gatoの論文のFigure 7)
シミュレーションでの結果
次にシミュレーション上での制御タスクの結果です。図9の横軸は学習データ中の最高スコアの何%かを表しており、縦軸は横軸のスコア (%) を上回ったタスクの数になります。
つまり、横軸の100%のラインは、Gatoが学習データよりも高いスコアを達成できたタスク数を表していることになります。Gatoのタスクのスコアは50回試行して得られたスコアの平均でもって計算されます。
図9:シミュレーションでの結果(出典:Gatoの論文のFigure 5)
学習データを超えて性能を発揮しているタスク (100%のライン) は、200弱と全604タスクの3分の1程度ですが、学習データの50%スコアを超えて性能を発揮しているタスクは450タスク以上あるという結果になっています。
実機ロボットでの結果
実機ロボットでの結果は以下の通りです。
表3:実機ロボットでの汎化性の評価(出典:Gatoの論文のTable 2)
この表3の結果は、訓練データ中にない物体をテスト時に与えてブロックを積み上げることができるかを検証した、汎化性をみる実験の結果です。表中のそれぞれのGroupは形状が異なるテスト物体の組み合わせを示しています。
結果としては、ベースラインの手法 (BC-IMP) と比べて同等程度の性能が発揮されていることがわかります。
分析
Gatoの論文では、純粋なタスクの性能評価だけでなく、モデルのスケール則や汎化性、fine-tuningについても検証しています。
Scaling Lawの分析
図10:スケール則の評価(出典:Gatoの論文のFigure 8)
図10の横軸は学習データ数、縦軸はタスクのエキスパートスコアを100としたときのスコアを表しています。学習データ数にかかわらず、モデルサイズを大きくするとスコアが改善されることがわかります。
汎化性、fine-tuningの分析
図11は、様々な事前学習モデルを用意して、事前学習に含まれていないタスクのデータをfew-shotで学習した結果になります。図中の4つのモデルは、
- all data: Gato
- same domain only data: fine-tuning先のタスクを除いた同じドメインのタスクのみで事前学習したモデル
- no control data: 制御タスクのデータセットを除いたデータセットで事前学習したモデル
- scratch: fine-tuning先のタスクのfew-shotデータで位置から学習したモデル
を表しています。
図11:few-shot fine-tuningした時の性能評価(出典:Gatoの論文のFigure 9)
scratch以外の3つの事前学習モデルは、fine-tuning先のタスクでfew-shot学習した結果となっています。
この実験は、どういった事前学習が後段のfine-tuningに有効か検証したものであり、Cartpole swingupやMetaworld assemblyではGato (all data) が事前学習として有効に働いていることを示しています。
ロボット制御タスクにおいても同様のfine-tuningの検証実験を行っています。
図12:ロボットタスクにおける汎化性の評価(出典:Gatoの論文のFigure 10)
図12は、事前学習中にないテストデータをfine-tuningしていった時の成功率をプロットしたグラフになっています。
Gatoを対象テストデータの10エピソード分だけfine-tuningするだけでエキスパートと同程度の性能を発揮することがわかります。
また、Gatoは視覚的な変化に対してもfine-tuningで対応できることが示されており、事前学習段階では、赤いブロックを青いブロックに積み上げるデータ (図13の上段) しか学習していませんが、
青いブロックを緑色のブロックに積み上げるというデータ (図13の下段) を500エピソード分だけ学習することで60%の成功率を達成できています。
500エピソード分のデータだけでBehavior Cloning (BC)したベースラインモデルが0.5%の成功率しか達成できないことを考えると、この結果からGatoの汎化性が垣間見えると思います。
図13:実機ロボットでの汎化性の評価(出典:Gatoの論文のFigure 11)
まとめ
Gatoは、大規模モデルを大規模データセットで事前学習することで、イメージキャプショニングタスクからシミュレーションや実機ロボットにおける制御タスクまで様々なタスクを解くことを可能にしました。
そして、ロボット制御タスクにおいてGatoの事前学習は汎化性においても貢献があることが示されました。
また、Gatoの大事なポイントの一つは、領域ごとに異なるモダリティを扱うために、すべての情報をトークン化し、さらにトークンを埋め込み表現にすることで統一的な系列情報へと変換している点です。
これは、将来より高度なロボット制御を基盤モデルで扱う上でマルチモーダル情報 (聴覚や触覚) を取り入れる必要性に迫られた際に、それら多様な情報を扱うための重要な工夫の一つであると感じます。