2
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

AIモデルを遊ぶAdvent Calendar 2024

Day 5

DeepMind「GenCast」次世代高精度天気予測モデル2「GraphCast」モデルを利用し全変数の可視化

Last updated at Posted at 2024-12-06

DeepMind「GenCast」次世代高精度天気予測モデル2「GraphCast」モデルを利用し全変数の可視化

こんにちは、しゅんです!

今回は、Google DeepMindの高精度天気予測モデル「GraphCast」で提供されるデータセットを用いて、すべての変数を可視化する方法をご紹介します。このモデルは、膨大な気象データを効率的に処理するために設計されていますが、この記事では、モデルのトレーニングや推論を行うのではなく、提供されたデータの内容を理解するための可視化に焦点を当てます。

僕はJupyter Notebookが嫌いなので(笑)、今回もPythonスクリプト形式で実行しました!この記事を参考に、皆さんもデータセットの内容をぜひ確認してみてください。


GraphCastとは?

概要

GraphCastは、DeepMindが開発したグラフニューラルネットワーク(GNN)を基盤とする天気予測モデルです。このモデルは、膨大な気象データを効率的に処理し、気温、降水量、風速など多様な気象変数を高精度に予測することを目的としています。

具体的な特徴:

  • 最大15日先までの天気予測が可能
  • 台風進路や気象現象の変化をモデル化
  • 多様な気象変数に対応
  • 柔軟性が高く、異なる解像度やデータセットでの利用が可能

今回の目的

この記事では、GraphCastが提供する過去の観測データセット(例: ERA5 データ)の内容を可視化し、各変数がどのような情報を持つかを確認します。
予測モデルの利用やトレーニングはこの記事では扱いません。


データセットの概要

以下は、データセットに含まれる変数とその次元情報の一覧です。

変数名 次元情報 データの種類 説明
geopotential_at_surface (lat, lon) 静的データ 地表での地形高度
land_sea_mask (lat, lon) 静的データ 陸地(1)と海洋(0)の識別
2m_temperature (batch, time, lat, lon) 時間依存データ 地表付近(2m)の気温
mean_sea_level_pressure (batch, time, lat, lon) 時間依存データ 平均海面気圧
10m_v_component_of_wind (batch, time, lat, lon) 時間依存データ 地表付近(10m)の南北方向の風速
10m_u_component_of_wind (batch, time, lat, lon) 時間依存データ 地表付近(10m)の東西方向の風速
total_precipitation_6hr (batch, time, lat, lon) 時間依存データ 6時間ごとの総降水量
toa_incident_solar_radiation (batch, time, lat, lon) 時間依存データ 大気上端での入射太陽放射量
temperature (batch, time, level, lat, lon) 時間依存データ(レベル有り) 特定の大気層での温度
geopotential (batch, time, level, lat, lon) 時間依存データ(レベル有り) 特定の気圧面での地形高度
u_component_of_wind (batch, time, level, lat, lon) 時間依存データ(レベル有り) 特定の気圧面での東西方向の風速
v_component_of_wind (batch, time, level, lat, lon) 時間依存データ(レベル有り) 特定の気圧面での南北方向の風速
vertical_velocity (batch, time, level, lat, lon) 時間依存データ(レベル有り) 特定の気圧面での鉛直速度
specific_humidity (batch, time, level, lat, lon) 時間依存データ(レベル有り) 特定の気圧面での比湿

可視化のコード

以下は、データセットを可視化するコードの全体です。

graphcast_demo_without_train.py
import dataclasses
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import xarray
from graphcast import data_utils
from graphcast import graphcast
from graphcast import normalization
from graphcast import rollout
from graphcast import xarray_jax

# データセットのパスを設定
DATASET_PATH = "./graphcast/dataset/source-era5_date-2022-01-01_res-0.25_levels-13_steps-12.nc"

# プロット関数
def select(data, variable, level=None, max_steps=None):
    data = data[variable]
    if "batch" in data.dims:
        data = data.isel(batch=0)
    if max_steps is not None and "time" in data.sizes and max_steps < data.sizes["time"]:
        data = data.isel(time=range(0, max_steps))
    if level is not None and "level" in data.coords:
        data = data.sel(level=level)
    return data

def scale(data, center=None, robust=False):
    vmin = np.nanpercentile(data, 2 if robust else 0)
    vmax = np.nanpercentile(data, 98 if robust else 100)
    if center is not None:
        diff = max(vmax - center, center - vmin)
        vmin = center - diff
        vmax = center + diff
    return (data, matplotlib.colors.Normalize(vmin, vmax), "RdBu_r" if center is not None else "viridis")

def plot_data(data, title, plot_size=10):
    cols = len(data)
    rows = 1
    fig, axs = plt.subplots(rows, cols, figsize=(plot_size * cols, plot_size))
    if cols == 1:
        axs = [axs]
    for i, (key, (values, norm, cmap)) in enumerate(data.items()):
        ax = axs[i]
        im = ax.imshow(values, norm=norm, origin="lower", cmap=cmap)
        ax.set_title(key)
        fig.colorbar(im, ax=ax)
    plt.tight_layout()
    plt.show()

# データセットをロード
example_batch = xarray.load_dataset(DATASET_PATH).compute()

print("Dataset variables:", example_batch.data_vars.keys())
print("Dataset dimensions:", example_batch.sizes)

# すべてのデータ変数をプロット
for variable in example_batch.data_vars.keys():
    try:
        print(f"Plotting variable: {variable}")
        data = example_batch[variable]

        if "time" in data.dims:
            level = data.coords["level"].values[0] if "level" in data.dims else None
            data_to_plot = select(example_batch, variable, level)
            scaled_data = {f"{variable}": scale(data_to_plot.isel(time=0))}
            plot_data(scaled_data, f"{variable} (time=0)")
        elif "lat" in data.dims and "lon" in data.dims:
            scaled_data = {f"{variable}": scale(data)}
            plot_data(scaled_data, f"{variable} (Static Data)")
        else:
            print(f"Skipping {variable}: Unsupported dimensions {data.dims}")
    except Exception as e:
        print(f"Error while plotting {variable}: {e}")


実行結果

コードを実行すると、データセット内の各変数がプロットされます。以下はその一部です。表を参照:

  1. 2m_temperature: 地表付近の気温
  2. mean_sea_level_pressure: 平均海面気圧
  3. geopotential: 特定の気圧面での地形高度
    Dataset variables: KeysView(Data variables:
        geopotential_at_surface       (lat, lon) float32 4MB 2.735e+04 ... -0.07617
        land_sea_mask                 (lat, lon) float32 4MB 1.0 1.0 1.0 ... 0.0 0.0
        2m_temperature                (batch, time, lat, lon) float32 58MB 250.7 ...
        mean_sea_level_pressure       (batch, time, lat, lon) float32 58MB 9.931e...
        10m_v_component_of_wind       (batch, time, lat, lon) float32 58MB -0.439...
        10m_u_component_of_wind       (batch, time, lat, lon) float32 58MB 1.309 ...
        total_precipitation_6hr       (batch, time, lat, lon) float32 58MB 0.0004...
        toa_incident_solar_radiation  (batch, time, lat, lon) float32 58MB 1.981e...
        temperature                   (batch, time, level, lat, lon) float32 756MB ...
        geopotential                  (batch, time, level, lat, lon) float32 756MB ...
        u_component_of_wind           (batch, time, level, lat, lon) float32 756MB ...
        v_component_of_wind           (batch, time, level, lat, lon) float32 756MB ...
        vertical_velocity             (batch, time, level, lat, lon) float32 756MB ...
        specific_humidity             (batch, time, level, lat, lon) float32 756MB ...)
    Dataset dimensions: Frozen({'lon': 1440, 'lat': 721, 'time': 14, 'level': 13, 'batch': 1})
    Plotting variable: geopotential_at_surface
    Plotting variable: land_sea_mask
    Plotting variable: 2m_temperature
    Plotting variable: mean_sea_level_pressure
    Plotting variable: 10m_v_component_of_wind
    Plotting variable: 10m_u_component_of_wind
    Plotting variable: total_precipitation_6hr
    Plotting variable: toa_incident_solar_radiation
    Plotting variable: temperature
    Plotting variable: geopotential
    Plotting variable: u_component_of_wind
    Plotting variable: v_component_of_wind
    Plotting variable: vertical_velocity
    Plotting variable: specific_humidity
    
    

注意点

  • このコードはデータセット内の観測値や初期条件を可視化することを目的としています。
  • 未来の予測やモデルの推論結果は含まれていません

画像

10m_u_component_of_wind.png

10m_v_component_of_wind.png

geopotential.png

geopotential_at_surface.png

land_sea_mask.png

mean_sea_level_pressure.png

specific_humidity.png

temperature.png

toa_incident_solar_radiation.png

total_precipitation_6hr.png

u_component_of_wind.png

v_component_of_wind.png

vertical_velocity.png

結論

GraphCast データセットには多様な気象変数が含まれており、その可視化を通じて気象現象を直感的に理解することができます。次回は、GraphCast モデルの予測プロセスや推論結果の分析に挑戦する予定です!

最後まで読んでいただきありがとうございました!この記事を参考に、ぜひ GraphCast のデータセットを試してみてください。

2
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?