DeepMind「GenCast」次世代高精度天気予測モデル2「GraphCast」モデルを利用し全変数の可視化
こんにちは、しゅんです!
今回は、Google DeepMindの高精度天気予測モデル「GraphCast」で提供されるデータセットを用いて、すべての変数を可視化する方法をご紹介します。このモデルは、膨大な気象データを効率的に処理するために設計されていますが、この記事では、モデルのトレーニングや推論を行うのではなく、提供されたデータの内容を理解するための可視化に焦点を当てます。
僕はJupyter Notebookが嫌いなので(笑)、今回もPythonスクリプト形式で実行しました!この記事を参考に、皆さんもデータセットの内容をぜひ確認してみてください。
GraphCastとは?
概要
GraphCastは、DeepMindが開発したグラフニューラルネットワーク(GNN)を基盤とする天気予測モデルです。このモデルは、膨大な気象データを効率的に処理し、気温、降水量、風速など多様な気象変数を高精度に予測することを目的としています。
具体的な特徴:
- 最大15日先までの天気予測が可能
- 台風進路や気象現象の変化をモデル化
- 多様な気象変数に対応
- 柔軟性が高く、異なる解像度やデータセットでの利用が可能
今回の目的
この記事では、GraphCastが提供する過去の観測データセット(例: ERA5 データ)の内容を可視化し、各変数がどのような情報を持つかを確認します。
予測モデルの利用やトレーニングはこの記事では扱いません。
データセットの概要
以下は、データセットに含まれる変数とその次元情報の一覧です。
変数名 | 次元情報 | データの種類 | 説明 |
---|---|---|---|
geopotential_at_surface |
(lat, lon) |
静的データ | 地表での地形高度 |
land_sea_mask |
(lat, lon) |
静的データ | 陸地(1)と海洋(0)の識別 |
2m_temperature |
(batch, time, lat, lon) |
時間依存データ | 地表付近(2m)の気温 |
mean_sea_level_pressure |
(batch, time, lat, lon) |
時間依存データ | 平均海面気圧 |
10m_v_component_of_wind |
(batch, time, lat, lon) |
時間依存データ | 地表付近(10m)の南北方向の風速 |
10m_u_component_of_wind |
(batch, time, lat, lon) |
時間依存データ | 地表付近(10m)の東西方向の風速 |
total_precipitation_6hr |
(batch, time, lat, lon) |
時間依存データ | 6時間ごとの総降水量 |
toa_incident_solar_radiation |
(batch, time, lat, lon) |
時間依存データ | 大気上端での入射太陽放射量 |
temperature |
(batch, time, level, lat, lon) |
時間依存データ(レベル有り) | 特定の大気層での温度 |
geopotential |
(batch, time, level, lat, lon) |
時間依存データ(レベル有り) | 特定の気圧面での地形高度 |
u_component_of_wind |
(batch, time, level, lat, lon) |
時間依存データ(レベル有り) | 特定の気圧面での東西方向の風速 |
v_component_of_wind |
(batch, time, level, lat, lon) |
時間依存データ(レベル有り) | 特定の気圧面での南北方向の風速 |
vertical_velocity |
(batch, time, level, lat, lon) |
時間依存データ(レベル有り) | 特定の気圧面での鉛直速度 |
specific_humidity |
(batch, time, level, lat, lon) |
時間依存データ(レベル有り) | 特定の気圧面での比湿 |
可視化のコード
以下は、データセットを可視化するコードの全体です。
import dataclasses
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import xarray
from graphcast import data_utils
from graphcast import graphcast
from graphcast import normalization
from graphcast import rollout
from graphcast import xarray_jax
# データセットのパスを設定
DATASET_PATH = "./graphcast/dataset/source-era5_date-2022-01-01_res-0.25_levels-13_steps-12.nc"
# プロット関数
def select(data, variable, level=None, max_steps=None):
data = data[variable]
if "batch" in data.dims:
data = data.isel(batch=0)
if max_steps is not None and "time" in data.sizes and max_steps < data.sizes["time"]:
data = data.isel(time=range(0, max_steps))
if level is not None and "level" in data.coords:
data = data.sel(level=level)
return data
def scale(data, center=None, robust=False):
vmin = np.nanpercentile(data, 2 if robust else 0)
vmax = np.nanpercentile(data, 98 if robust else 100)
if center is not None:
diff = max(vmax - center, center - vmin)
vmin = center - diff
vmax = center + diff
return (data, matplotlib.colors.Normalize(vmin, vmax), "RdBu_r" if center is not None else "viridis")
def plot_data(data, title, plot_size=10):
cols = len(data)
rows = 1
fig, axs = plt.subplots(rows, cols, figsize=(plot_size * cols, plot_size))
if cols == 1:
axs = [axs]
for i, (key, (values, norm, cmap)) in enumerate(data.items()):
ax = axs[i]
im = ax.imshow(values, norm=norm, origin="lower", cmap=cmap)
ax.set_title(key)
fig.colorbar(im, ax=ax)
plt.tight_layout()
plt.show()
# データセットをロード
example_batch = xarray.load_dataset(DATASET_PATH).compute()
print("Dataset variables:", example_batch.data_vars.keys())
print("Dataset dimensions:", example_batch.sizes)
# すべてのデータ変数をプロット
for variable in example_batch.data_vars.keys():
try:
print(f"Plotting variable: {variable}")
data = example_batch[variable]
if "time" in data.dims:
level = data.coords["level"].values[0] if "level" in data.dims else None
data_to_plot = select(example_batch, variable, level)
scaled_data = {f"{variable}": scale(data_to_plot.isel(time=0))}
plot_data(scaled_data, f"{variable} (time=0)")
elif "lat" in data.dims and "lon" in data.dims:
scaled_data = {f"{variable}": scale(data)}
plot_data(scaled_data, f"{variable} (Static Data)")
else:
print(f"Skipping {variable}: Unsupported dimensions {data.dims}")
except Exception as e:
print(f"Error while plotting {variable}: {e}")
実行結果
コードを実行すると、データセット内の各変数がプロットされます。以下はその一部です。表を参照:
-
2m_temperature
: 地表付近の気温 -
mean_sea_level_pressure
: 平均海面気圧 -
geopotential
: 特定の気圧面での地形高度Dataset variables: KeysView(Data variables: geopotential_at_surface (lat, lon) float32 4MB 2.735e+04 ... -0.07617 land_sea_mask (lat, lon) float32 4MB 1.0 1.0 1.0 ... 0.0 0.0 2m_temperature (batch, time, lat, lon) float32 58MB 250.7 ... mean_sea_level_pressure (batch, time, lat, lon) float32 58MB 9.931e... 10m_v_component_of_wind (batch, time, lat, lon) float32 58MB -0.439... 10m_u_component_of_wind (batch, time, lat, lon) float32 58MB 1.309 ... total_precipitation_6hr (batch, time, lat, lon) float32 58MB 0.0004... toa_incident_solar_radiation (batch, time, lat, lon) float32 58MB 1.981e... temperature (batch, time, level, lat, lon) float32 756MB ... geopotential (batch, time, level, lat, lon) float32 756MB ... u_component_of_wind (batch, time, level, lat, lon) float32 756MB ... v_component_of_wind (batch, time, level, lat, lon) float32 756MB ... vertical_velocity (batch, time, level, lat, lon) float32 756MB ... specific_humidity (batch, time, level, lat, lon) float32 756MB ...) Dataset dimensions: Frozen({'lon': 1440, 'lat': 721, 'time': 14, 'level': 13, 'batch': 1}) Plotting variable: geopotential_at_surface Plotting variable: land_sea_mask Plotting variable: 2m_temperature Plotting variable: mean_sea_level_pressure Plotting variable: 10m_v_component_of_wind Plotting variable: 10m_u_component_of_wind Plotting variable: total_precipitation_6hr Plotting variable: toa_incident_solar_radiation Plotting variable: temperature Plotting variable: geopotential Plotting variable: u_component_of_wind Plotting variable: v_component_of_wind Plotting variable: vertical_velocity Plotting variable: specific_humidity
注意点
- このコードはデータセット内の観測値や初期条件を可視化することを目的としています。
- 未来の予測やモデルの推論結果は含まれていません。
画像
結論
GraphCast データセットには多様な気象変数が含まれており、その可視化を通じて気象現象を直感的に理解することができます。次回は、GraphCast モデルの予測プロセスや推論結果の分析に挑戦する予定です!
最後まで読んでいただきありがとうございました!この記事を参考に、ぜひ GraphCast のデータセットを試してみてください。