概要
以前、Google Deepmindが開発したGraphCastの論文の解説記事を書きました。
本記事では、GraphCastのモデルの中身についてもう少し詳しく解説していきます。
出典は、前回同様こちらの論文からです。
モデルの概要
前回も少し触れましたが、GraphCastのモデル概要について説明し、ざっくりしたイメージをつかみます。
上図のようにGraphCastは大きく3つのプロセスがあり、
- Encoder
- Processor
- Decoder
があります。
入力値は緯度経度グリッド(grid node)上にあり、それをEncoderでグラフのノード(mesh node)にマッピングします。その後Processorでノード間で情報の伝達(message-passing)を行い、Decoderで元の緯度経度グリッドへマッピングします。
これが一連の流れになります。
multi-mesh構造
GraphCastはグラフニューラルネットワークを用いたモデルですが、そのグラフの構造にmuti-meshというものを提案しています。
さっきの図の(g)がmuti-meshの部品を表しています。
$\mathrm{M}^0$は正20面体で、$\mathrm{M}^1$は$\mathrm{M}^0$の各面を4分割してできる80面体です。同様に、$\mathrm{M}^i$は$\mathrm{M}^{i-1}$の各面を4分割してできる立体です。論文では、$\mathrm{M}^6$(81,920面体)まで作っています。
こうしてできた$\mathrm{M}^6$の頂点をノードとします(40,962個)。エッジは、$\mathrm{M}^0,\cdots,\mathrm{M}^6$の各辺とします($\mathrm{M}^6$のエッジだけじゃない)。
このグラフを図示したのがさっきの図の(e)となります。一番細い線が$\mathrm{M}^6$のエッジに対応し、太い矢印は$\mathrm{M}^0$のエッジの一部を示していることがわかります。
緯度経度グリッドと同じようなグラフ構造を持たせるのが最も素直だと思ったのですが、緯度経度グリッドは空間的に不均一であるというデメリットがあります。
下図のように、極ではグリッドが密集している一方、赤道へ近づくにつれて疎になっていきます。
ですが、multi-meshはほぼ均一にノードが分布するので上記のような不釣り合いが起きません。
また、multi-meshの$\mathrm{M}^0$のような粗い構造由来のエッジは長距離の相互依存を伝達でき、$\mathrm{M}^6$のような細かい構造由来のエッジは局所的な相互作用を集計することができます。
CNNでは下図のような畳み込み演算の性質上、局所的な特徴を抽出するのは得意ですが、長距離依存性を抽出するのは困難でした。
また、Transformerでは長距離依存性を考慮できますが、Transformerの計算量は入力系列長の2乗オーダーであり、100万以上あるグリッド点の入力を処理しようと思うと結構ヤバそうです。
最近ではTransformerの計算量問題の緩和が試みられていますが、これらはAttentionをスパース化しているものが多く、実質これはGNNでやっていることと同じではないか、と論文では述べられています。
GraphCastのグラフ構造
この節では、GraphCastのグラフにおけるノードとエッジの定義について述べていきます。
Grid nodes
$\mathcal{V}^\mathrm{G}$は、緯度経度グリッド上のノード(grid node) $v_i^\mathrm{G}$の集合を表します。各ノード$v_i^\mathrm{G}$は以下のような特徴量$\mathbf{v}_i^{\mathrm{G,features}}$を保持しています。
\mathbf{v}_i^{\mathrm{G,features}} = [\mathbf{x}_i^{t-1},\mathbf{x}_i^{t},\mathbf{f}_i^{t-1},\mathbf{f}_i^{t},\mathbf{f}_i^{t+1},\mathbf{c}_i]
$\mathbf{x}_i^t$は、ノード$v_i^\mathrm{G}$位置に対応する時間依存の気象状態$X^t$で、地表変数だけでなく大気37高度のすべての変数を含みます。
$\mathbf{f}_i^t$は、①大気上層における1時間あたりの総入射日射量、②現地時間をdayで表した値(0~1)のsin、③現地時間をdayで表した値(0~1)のcos、④日付をyearで表した値(0~1)のsin、⑤日付をyearで表した値(0~1)のcos、です。これらはすべて計算で求まるものであって、観測量ではありません。
$\mathbf{c}_i$は、①陸か海かの0-1マスク、②地表のジオポテンシャル、③緯度のcos、④経度のsin、⑤経度のcos、です。
以上より、$\mathbf{v}_i^{\mathrm{G,features}}$の次元は
\begin{align}
&(5 \text{ surface variables} + 6\text{ atmospheric variables } \times 37\text{ levels})\times 2 \text{ steps} \\
&+ 5 \text{ forcings} \times 3 \text{ steps}\\
&+ 5 \text{ constant}\\
&= 474 \text{ inputs}
\end{align}
となります。
Mesh nodes
$\mathcal{V}^\mathrm{M}$は、multi-mesh上のノード(mesh node)$v_i^\mathrm{M}$の集合を表します。各ノード$v_i^\mathrm{M}$は以下のような3次元の特徴量$\mathbf{v}_i^{\mathrm{M,features}}$を保持しています。
\mathbf{v}_i^{\mathrm{M,features}} = [\cos(latitude),\sin(longitude),\cos(longitude)]
つまり、meshノードでは位置の情報が保持されています。
Mesh edges
$\mathcal{E}^\mathrm{M}$は、meshノードどうしをつなぐエッジ$e_{v_s^\mathrm{M}\to v_r^\mathrm{M}}^\mathrm{M}$の集合です。
添え字は、meshノード(sender)$v_s^\mathrm{M}$ から meshノード(receiver)$v_r^\mathrm{M}$をつなぐエッジであることを示しています。
各エッジは以下のような4次元の特徴量$\mathbf{e}_{v_s^\mathrm{M}\to v_r^\mathrm{M}}^\mathrm{M,features}$を保持しています。
\mathbf{e}_{v_s^\mathrm{M}\to v_r^\mathrm{M}}^\mathrm{M,features} = [\text{edge length}, \mathbf{d}_{v_s^\mathrm{M}\to v_r^\mathrm{M}}]
$\text{edge length}$はエッジの長さです。
$\mathbf{d}_{v_s^\mathrm{M}\to v_r^\mathrm{M}}$は、receiverノード$v_s^\mathrm{M}$における局所座標のもとで計算した、senderノード$v_s^\mathrm{M}$の位置ベクトルとreceiverノード$v_r^\mathrm{M}$の位置ベクトルの差です。
Grid2Mesh edges
$\mathcal{E}^\mathrm{G2M}$は、gridノード$v_s^\mathrm{G}$からmeshノード$v_r^\mathrm{M}$をつなぐエッジ$e_{v_s^\mathrm{G}\to v_r^\mathrm{M}}^\mathrm{G2M}$の集合です。
gridノード$v_s^\mathrm{G}$からmeshノード$v_r^\mathrm{M}$の距離が、$\mathrm{M}^6$メッシュにおけるエッジの長さの0.6倍以下の場合のみ、両ノードがエッジで繋がれます。
$e_{v_s^\mathrm{G}\to v_r^\mathrm{M}}^\mathrm{G2M}$の特徴量はMesh edgesの場合と同様で、
\mathbf{e}_{v_s^\mathrm{G}\to v_r^\mathrm{M}}^\mathrm{G2M,features} = [\text{edge length}, \mathbf{d}_{v_s^\mathrm{G}\to v_r^\mathrm{M}}]
です。
Mesh2Grid edges
$\mathcal{E}^\mathrm{M2G}$は、meshノード$v_s^\mathrm{M}$からgridノード$v_s^\mathrm{G}$をつなぐエッジ$e_{v_s^\mathrm{M}\to v_r^\mathrm{G}}^\mathrm{M2G}$の集合です。
各gridノードに対し、$\mathrm{M}^6$メッシュの面(三角形)が対応する。つまり、1つのgridノードには3つのmeshノードが対応する(最初の図の(f)がわかりやすい)。
$e_{v_s^\mathrm{M}\to v_r^\mathrm{G}}^\mathrm{M2G}$の特徴量はMesh edgesと同様で、
\mathbf{e}_{v_s^\mathrm{M}\to v_r^\mathrm{G}}^\mathrm{M2G,features} = [\text{edge length}, \mathbf{d}_{v_s^\mathrm{M}\to v_r^\mathrm{G}}]
です。
Encoderの中身
入力特徴量の埋め込み
Encoderでは、上で定義した特徴量
\mathbf{v}_i^{\mathrm{G,features}},
\mathbf{v}_i^{\mathrm{M,features}},
\mathbf{e}_{v_s^\mathrm{M}\to v_r^\mathrm{M}}^\mathrm{M,features},
\mathbf{e}_{v_s^\mathrm{G}\to v_r^\mathrm{M}}^\mathrm{G2M,features},
\mathbf{e}_{v_s^\mathrm{M}\to v_r^\mathrm{G}}^\mathrm{M2G,features}
を、固定次元の潜在空間へ多層パーセプトロン(Multi Linear Perceptron:MLP)で埋め込みます。
これらを埋め込んだ結果をそれぞれ
\mathbf{v}_i^{\mathrm{G}},
\mathbf{v}_i^{\mathrm{M}},
\mathbf{e}_{v_s^\mathrm{M}\to v_r^\mathrm{M}}^\mathrm{M},
\mathbf{e}_{v_s^\mathrm{G}\to v_r^\mathrm{M}}^\mathrm{G2M},
\mathbf{e}_{v_s^\mathrm{M}\to v_r^\mathrm{G}}^\mathrm{M2G}
とします。
Grid2Mesh GNN
Encoderでは次に、大気状態の情報をgridノードからmeshノードに転送するために、gridノードとmeshノードをつなぐGrid2Meshの2部グラフ$\mathcal{G}_\mathrm{G2M}(\mathcal{V}^\mathrm{G}, \mathcal{V}^\mathrm{M}, \mathcal{E}^\mathrm{G2M})$上で1回のmessage-passingを行います。その処理は今から説明します。
まず、Grid2Meshの各エッジ$\mathbf{e}_{v_s^\mathrm{G}\to v_r^\mathrm{M}}^\mathrm{G2M}$は、エッジ両端のノードの情報$\mathbf{v}_s^\mathrm{G},\mathbf{v}_r^\mathrm{M}$を使って更新されます:
その後、各meshノード$\mathbf{v}_i^\mathrm{M}$は、そのmeshノードに向いている全エッジからの情報の総和を使って更新されます:
gridノード$\mathbf{v}_i^\mathrm{G}$の方は、単にMLPへ通すだけです:
これら3つの要素$\mathbf{e}_{v_s^\mathrm{G}\to v_r^\mathrm{M}}^\mathrm{G2M},\mathbf{v}_i^\mathrm{M},\mathbf{v}_i^\mathrm{G}$を更新した後、更新前の値を足し合わせて(残差接続)Encoderの出力結果となります:
Processorの中身
Multi-mesh GNN
まず、両端のノードの情報を使って各meshエッジ$\mathbf{e}_{v_s^\mathrm{M}\to v_r^\mathrm{M}}^\mathrm{M}$を更新します:
その後、各meshノード$\mathbf{v}_i^\mathrm{M}$は、そのmeshノードに向いている全エッジからの情報の総和を使って更新されます:
これら2つの要素$\mathbf{e}_{v_s^\mathrm{M}\to v_r^\mathrm{M}}^\mathrm{M},\mathbf{v}_i^\mathrm{M}$を更新した後、更新前の値を足し合わせて(残差接続)Processorの出力結果となります:
論文では、これらの処理を16回行っています(つまり16層のinteraction network layerを重ねている)。パラメータ共有は行っておらず、全層でパラメータは独立としているそうです。
Decoderの中身
Mesh2Grid GNN
Grid2Mesh GNNと同様に、Mesh2Grid GNNでは2部グラフ$\mathcal{G}_\mathrm{M2G}(\mathcal{V}^\mathrm{G}, \mathcal{V}^\mathrm{M}, \mathcal{E}^\mathrm{M2G})$上で1回のmessage-passingを行います。
Grid2Mesh GNN でやっていることを逆にしただけの処理なので、説明は割愛します。
Output function
ここまで処理してきた情報は、いま$\mathbf{v}_i^\mathrm{G}$たちに集約されています。あとはこれをMLPに通して予測値$\hat{\mathbf{y}}_i$を得ます:
ここでの予測値$\hat{\mathbf{y}}_i$は気温などそのものの値ではなく、前時刻からの差分を表すように学習しています。なので、気温などそのものの値を知りたければ、前時刻での値$X^t$を$\hat{\mathbf{y}}_i$に足せば良いことになります:
最後に
以上がGraphCastのモデルの中身についての詳細解説でした。
一言で言うと、
緯度経度グリッド上の特徴量をmulti-mesh上に射影し、16層のProcessorを通じて情報伝達を行い、再度緯度経度グリッド上に射影して予測値を得る
という仕組みでした。数式をパッと見ると、複雑そうなことをしているように見えますが、添え字がたくさんついているだけでやっていることは意外と単純でした。
…とはいえ、高次元かつ大量のデータでモデルを学習させるにはかなりの計算リソースが必要なわけで、実際に実装するとなるとかなり大変だと思いますが。(32GBのTPU×32個をつかって4週間かかったそう)