DeepMindの「GenCast」:次世代高精度天気予測モデルを試す方法
こんにちは、しゅんです!
今日は、Google DeepMindが公開した次世代の高精度天気予測モデル「GenCast」を取り上げます。このモデルは、生成的AIを駆使し、従来の天気予測技術を超えた性能を発揮します。
ただし、皆さんもご存じの通り、僕はJupyter Notebookが嫌いです(笑)。
そのため、今回もNO NOtebook !!「Notebook形式からPythonスクリプト形式に変換」して実行できるようにしました!
この記事では、モデルの概要からセットアップ方法、そして僕がNotebook形式をあえて嫌う理由についても触れていきます。
GenCastとは?
概要
GenCastは、DeepMindが開発した生成的AIを基盤とする天気予測モデルです。これにより、従来のモデルでは対応が難しかった不確実性を含む気象データの予測を可能にしています。
具体的には以下の特徴があります:
- 最大15日先までの高精度な天気予測
- 台風の進路予測など、大規模な気象現象に対応
- アンサンブルモデルを使用し、複数の予測結果を統合して不確実性を可視化
- 生成的AIを採用し、新しい気象サンプルの生成も可能
実績例
- 2019年に日本を襲った台風19号の進路予測において、上陸7日前から正確な予測を実現しました。
オープンソース
Google DeepMindは2025年12月4日(現地時間)、生成AIベースの高精度天気予測モデル「GenCast」を発表した。出たばかりのモデルです。
DeepMindは、以下のリソースでGenCastをオープンソースとして公開しています:
セットアップ方法
1. 環境の準備
データセットを取得するには、Google Cloud Storage用のツールgsutil
が必要です。もちろん一つずつダウンロードも可能ですが、その場合はいらないかもしれませんがめんどくさいから、gsutil
をインストール’することをすすめします。
GraphCastとGenCastのモデルファイルは、Google Cloud Storage Bucket から入手可能です。以下に、モデルファイルを取得する方法を詳しく説明します。
公式手順はこちら:
2. データセットとモデルのダウンロード
以下のコマンドで、必要なデータセットとモデルを一括ダウンロードできます。
注意事項
- 必要ストレージ:約147GB
- ダウンロード速度に応じて時間がかかる場合があります。
gsutil -m cp -r \
"gs://dm_graphcast/LICENSE" \
"gs://dm_graphcast/dataset" \
"gs://dm_graphcast/gencast" \
"gs://dm_graphcast/graphcast" \
"gs://dm_graphcast/params" \
"gs://dm_graphcast/stats" \
.
3. GenCastのセットアップ
以下のコマンドを順に実行し、環境をセットアップします。
# 作業ディレクトリを作成
mkdir google_deepmind
cd google_deepmind
# リポジトリをクローン
git clone https://github.com/google-deepmind/graphcast.git
cd graphcast
# 必要なパッケージをインストール
pip install -e .
pip install --upgrade https://github.com/deepmind/graphcast/archive/master.zip
sudo apt install imagemagick
4. Notebook形式をスクリプト形式に変換
DeepMindの公式リポジトリにはJupyter Notebook形式のデモスクリプトが含まれています。
しかし、僕はNotebookが嫌いです(笑)。視覚的に優れていますが、以下の理由でスクリプト形式に変換しました:
- コードの再利用性が低い
- 実行環境に依存する場合が多い
- デバッグが面倒
そこで、gencast_demo_cloud_vm.ipynb
をまずpython
に出力させて(gencast_demo_cloud_vm.py)、コピーし、GPTに修正を任せました。
code
import dataclasses
import datetime
import math
from typing import Optional
import haiku as hk
import jax
import matplotlib.pyplot as plt
import numpy as np
import xarray
from matplotlib import animation
from graphcast import (
rollout,
xarray_jax,
normalization,
checkpoint,
data_utils,
xarray_tree,
gencast,
denoiser,
nan_cleaning,
)
# プロット関連関数
def select(data: xarray.Dataset, variable: str, level: Optional[int] = None, max_steps: Optional[int] = None):
data = data[variable]
if "batch" in data.dims:
data = data.isel(batch=0)
if max_steps is not None and "time" in data.sizes and max_steps < data.sizes["time"]:
data = data.isel(time=range(0, max_steps))
if level is not None and "level" in data.coords:
data = data.sel(level=level)
return data
def scale(data: xarray.Dataset, center: Optional[float] = None, robust: bool = False):
vmin = np.nanpercentile(data, (2 if robust else 0))
vmax = np.nanpercentile(data, (98 if robust else 100))
if center is not None:
diff = max(vmax - center, center - vmin)
vmin = center - diff
vmax = center + diff
return data, plt.Normalize(vmin, vmax), ("RdBu_r" if center is not None else "viridis")
def plot_data(data: dict[str, xarray.Dataset], fig_title: str, plot_size: float = 5, robust: bool = False, cols: int = 4):
first_data = next(iter(data.values()))[0]
max_steps = first_data.sizes.get("time", 1)
assert all(max_steps == d.sizes.get("time", 1) for d, _, _ in data.values())
cols = min(cols, len(data))
rows = math.ceil(len(data) / cols)
figure = plt.figure(figsize=(plot_size * 2 * cols, plot_size * rows))
figure.suptitle(fig_title, fontsize=16)
figure.subplots_adjust(wspace=0, hspace=0)
figure.tight_layout()
images = []
for i, (title, (plot_data, norm, cmap)) in enumerate(data.items()):
ax = figure.add_subplot(rows, cols, i + 1)
ax.set_xticks([])
ax.set_yticks([])
ax.set_title(title)
im = ax.imshow(plot_data.isel(time=0, missing_dims="ignore"), norm=norm, origin="lower", cmap=cmap)
plt.colorbar(mappable=im, ax=ax, orientation="vertical", pad=0.02, aspect=16, shrink=0.75, cmap=cmap)
images.append(im)
def update(frame):
for im, (plot_data, _, _) in zip(images, data.values()):
im.set_data(plot_data.isel(time=frame, missing_dims="ignore"))
ani = animation.FuncAnimation(fig=figure, func=update, frames=max_steps, interval=250)
plt.close(figure.number)
return ani
# モデルとデータのパス
MODEL_PATH = "./gencast/params/GenCast 1p0deg Mini <2019.npz"
DATA_PATH = "./gencast/dataset/source-era5_date-2019-03-29_res-1.0_levels-13_steps-04.nc"
STATS_DIR = "./gencast/stats/"
# モデルをロード
with open(MODEL_PATH, "rb") as f:
ckpt = checkpoint.load(f, gencast.CheckPoint)
params = ckpt.params
state = {}
task_config = ckpt.task_config
sampler_config = ckpt.sampler_config
noise_config = ckpt.noise_config
noise_encoder_config = ckpt.noise_encoder_config
denoiser_architecture_config = ckpt.denoiser_architecture_config
print("Model description:\n", ckpt.description, "\n")
print("Model license:\n", ckpt.license, "\n")
# データセットをロード
with open(DATA_PATH, "rb") as f:
example_batch = xarray.load_dataset(f).compute()
# assert example_batch.dims["time"] >= 3
assert example_batch.sizes["time"] >= 3
# 正規化データをロード
with open(STATS_DIR + "diffs_stddev_by_level.nc", "rb") as f:
diffs_stddev_by_level = xarray.load_dataset(f).compute()
with open(STATS_DIR + "mean_by_level.nc", "rb") as f:
mean_by_level = xarray.load_dataset(f).compute()
with open(STATS_DIR + "stddev_by_level.nc", "rb") as f:
stddev_by_level = xarray.load_dataset(f).compute()
with open(STATS_DIR + "min_by_level.nc", "rb") as f:
min_by_level = xarray.load_dataset(f).compute()
# GenCast構造を構築
def construct_wrapped_gencast():
predictor = gencast.GenCast(
sampler_config=sampler_config,
task_config=task_config,
denoiser_architecture_config=denoiser_architecture_config,
noise_config=noise_config,
noise_encoder_config=noise_encoder_config,
)
predictor = normalization.InputsAndResiduals(
predictor,
diffs_stddev_by_level=diffs_stddev_by_level,
mean_by_level=mean_by_level,
stddev_by_level=stddev_by_level,
)
predictor = nan_cleaning.NaNCleaner(
predictor=predictor,
reintroduce_nans=True,
fill_value=min_by_level,
var_to_clean="sea_surface_temperature",
)
return predictor
# トレーニングと評価データを抽出
train_inputs, train_targets, train_forcings = data_utils.extract_inputs_targets_forcings(
example_batch, target_lead_times=slice("12h", "12h"), **dataclasses.asdict(task_config)
)
eval_inputs, eval_targets, eval_forcings = data_utils.extract_inputs_targets_forcings(
# example_batch, target_lead_times=slice("12h", f"{(example_batch.dims['time'] - 2) * 12}h"), **dataclasses.asdict(task_config)
example_batch, target_lead_times=slice("12h", f"{(example_batch.sizes['time'] - 2) * 12}h"), **dataclasses.asdict(task_config)
)
# メイン実行
if __name__ == "__main__":
# プロットの例
variable = "geopotential"
level = 500
# steps = example_batch.dims["time"]
steps = example_batch.sizes["time"]
data = {" ": scale(select(example_batch, variable, level, steps), robust=True)}
ani = plot_data(data, variable, plot_size=7, robust=True)
ani.save("example_animation.gif", writer="imagemagick")
実行方法と結果
スクリプトを実行するには以下のコマンドを使用します:
python3 gencast_demo_cloud_vm.py
実行すると、以下のようなモデル情報が表示されます:
Model description:
GenCast model at lower, 1deg, resolution, with 13 pressure levels and a
4 times refined icosahedral mesh. This model is trained on ERA5 data
from 1979 to 2018, and can be causally evaluated on 2019 and later years.
This model has the smallest memory footprint of those provided and has been provided
to enable low cost demonstrations. It is not representative of GenCast's performance.
Model license:
The model weights are licensed under the Creative Commons
Attribution-NonCommercial-ShareAlike 4.0 International (CC BY-NC-SA 4.0). You
may obtain a copy of the License at:
https://creativecommons.org/licenses/by-nc-sa/4.0/.
The weights were trained on ERA5 data, see README for attribution statement.
Gif画像は保存されるはず
まとめ:Notebook vs スクリプト形式
Notebook形式には視覚的な魅力がありますが、長期的なプロジェクトやデバッグにやはりはスクリプト形式が適しています。今回、公式のNotebookをもとにPythonスクリプトを作成し、GenCastをローカルで動かすまでを解説しました。
次回の記事では、実際の天気データを用いた予測結果の可視化や性能評価に焦点を当てる予定ですがどうなるのがわからないです。
最後まで読んでいただきありがとうございました!
ぜひGenCastを試してみてください。