2
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

AIモデルを遊ぶAdvent Calendar 2024

Day 4

DeepMindの「GenCast」:次世代高精度天気予測モデルを試す方法(ローカル環境内)1「gencast_demo_cloud_vm」

Last updated at Posted at 2024-12-05

DeepMindの「GenCast」:次世代高精度天気予測モデルを試す方法

こんにちは、しゅんです!

今日は、Google DeepMindが公開した次世代の高精度天気予測モデル「GenCast」を取り上げます。このモデルは、生成的AIを駆使し、従来の天気予測技術を超えた性能を発揮します。
ただし、皆さんもご存じの通り、僕はJupyter Notebookが嫌いです(笑)。
そのため、今回もNO NOtebook !!Notebook形式からPythonスクリプト形式に変換」して実行できるようにしました!

この記事では、モデルの概要からセットアップ方法、そして僕がNotebook形式をあえて嫌う理由についても触れていきます。

GenCastとは?

概要

GenCastは、DeepMindが開発した生成的AIを基盤とする天気予測モデルです。これにより、従来のモデルでは対応が難しかった不確実性を含む気象データの予測を可能にしています。

具体的には以下の特徴があります:

  • 最大15日先までの高精度な天気予測
  • 台風の進路予測など、大規模な気象現象に対応
  • アンサンブルモデルを使用し、複数の予測結果を統合して不確実性を可視化
  • 生成的AIを採用し、新しい気象サンプルの生成も可能

実績例

  • 2019年に日本を襲った台風19号の進路予測において、上陸7日前から正確な予測を実現しました。

オープンソース

Google DeepMindは2025年12月4日(現地時間)、生成AIベースの高精度天気予測モデル「GenCast」を発表した。出たばかりのモデルです。
DeepMindは、以下のリソースでGenCastをオープンソースとして公開しています:

セットアップ方法

1. 環境の準備

データセットを取得するには、Google Cloud Storage用のツールgsutilが必要です。もちろん一つずつダウンロードも可能ですが、その場合はいらないかもしれませんがめんどくさいから、gsutilをインストール’することをすすめします。
GraphCastとGenCastのモデルファイルは、Google Cloud Storage Bucket から入手可能です。以下に、モデルファイルを取得する方法を詳しく説明します。

公式手順はこちら:

2. データセットとモデルのダウンロード

以下のコマンドで、必要なデータセットとモデルを一括ダウンロードできます。

注意事項

  • 必要ストレージ:約147GB
  • ダウンロード速度に応じて時間がかかる場合があります。
gsutil -m cp -r \
  "gs://dm_graphcast/LICENSE" \
  "gs://dm_graphcast/dataset" \
  "gs://dm_graphcast/gencast" \
  "gs://dm_graphcast/graphcast" \
  "gs://dm_graphcast/params" \
  "gs://dm_graphcast/stats" \
  .

3. GenCastのセットアップ

以下のコマンドを順に実行し、環境をセットアップします。

# 作業ディレクトリを作成
mkdir google_deepmind
cd google_deepmind

# リポジトリをクローン
git clone https://github.com/google-deepmind/graphcast.git
cd graphcast

# 必要なパッケージをインストール
pip install -e .
pip install --upgrade https://github.com/deepmind/graphcast/archive/master.zip
sudo apt install imagemagick

4. Notebook形式をスクリプト形式に変換

DeepMindの公式リポジトリにはJupyter Notebook形式のデモスクリプトが含まれています。
しかし、僕はNotebookが嫌いです(笑)。視覚的に優れていますが、以下の理由でスクリプト形式に変換しました:

  • コードの再利用性が低い
  • 実行環境に依存する場合が多い
  • デバッグが面倒

そこで、gencast_demo_cloud_vm.ipynbをまずpythonに出力させて(gencast_demo_cloud_vm.py)、コピーし、GPTに修正を任せました。

code

gencast_demo_cloud_vm.py
import dataclasses
import datetime
import math
from typing import Optional

import haiku as hk
import jax
import matplotlib.pyplot as plt
import numpy as np
import xarray
from matplotlib import animation
from graphcast import (
    rollout,
    xarray_jax,
    normalization,
    checkpoint,
    data_utils,
    xarray_tree,
    gencast,
    denoiser,
    nan_cleaning,
)

# プロット関連関数
def select(data: xarray.Dataset, variable: str, level: Optional[int] = None, max_steps: Optional[int] = None):
    data = data[variable]
    if "batch" in data.dims:
        data = data.isel(batch=0)
    if max_steps is not None and "time" in data.sizes and max_steps < data.sizes["time"]:
        data = data.isel(time=range(0, max_steps))
    if level is not None and "level" in data.coords:
        data = data.sel(level=level)
    return data

def scale(data: xarray.Dataset, center: Optional[float] = None, robust: bool = False):
    vmin = np.nanpercentile(data, (2 if robust else 0))
    vmax = np.nanpercentile(data, (98 if robust else 100))
    if center is not None:
        diff = max(vmax - center, center - vmin)
        vmin = center - diff
        vmax = center + diff
    return data, plt.Normalize(vmin, vmax), ("RdBu_r" if center is not None else "viridis")

def plot_data(data: dict[str, xarray.Dataset], fig_title: str, plot_size: float = 5, robust: bool = False, cols: int = 4):
    first_data = next(iter(data.values()))[0]
    max_steps = first_data.sizes.get("time", 1)
    assert all(max_steps == d.sizes.get("time", 1) for d, _, _ in data.values())

    cols = min(cols, len(data))
    rows = math.ceil(len(data) / cols)
    figure = plt.figure(figsize=(plot_size * 2 * cols, plot_size * rows))
    figure.suptitle(fig_title, fontsize=16)
    figure.subplots_adjust(wspace=0, hspace=0)
    figure.tight_layout()

    images = []
    for i, (title, (plot_data, norm, cmap)) in enumerate(data.items()):
        ax = figure.add_subplot(rows, cols, i + 1)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(title)
        im = ax.imshow(plot_data.isel(time=0, missing_dims="ignore"), norm=norm, origin="lower", cmap=cmap)
        plt.colorbar(mappable=im, ax=ax, orientation="vertical", pad=0.02, aspect=16, shrink=0.75, cmap=cmap)
        images.append(im)

    def update(frame):
        for im, (plot_data, _, _) in zip(images, data.values()):
            im.set_data(plot_data.isel(time=frame, missing_dims="ignore"))

    ani = animation.FuncAnimation(fig=figure, func=update, frames=max_steps, interval=250)
    plt.close(figure.number)
    return ani

# モデルとデータのパス
MODEL_PATH = "./gencast/params/GenCast 1p0deg Mini <2019.npz"
DATA_PATH = "./gencast/dataset/source-era5_date-2019-03-29_res-1.0_levels-13_steps-04.nc"
STATS_DIR = "./gencast/stats/"

# モデルをロード
with open(MODEL_PATH, "rb") as f:
    ckpt = checkpoint.load(f, gencast.CheckPoint)
params = ckpt.params
state = {}

task_config = ckpt.task_config
sampler_config = ckpt.sampler_config
noise_config = ckpt.noise_config
noise_encoder_config = ckpt.noise_encoder_config
denoiser_architecture_config = ckpt.denoiser_architecture_config
print("Model description:\n", ckpt.description, "\n")
print("Model license:\n", ckpt.license, "\n")

# データセットをロード
with open(DATA_PATH, "rb") as f:
    example_batch = xarray.load_dataset(f).compute()

# assert example_batch.dims["time"] >= 3
assert example_batch.sizes["time"] >= 3
# 正規化データをロード
with open(STATS_DIR + "diffs_stddev_by_level.nc", "rb") as f:
    diffs_stddev_by_level = xarray.load_dataset(f).compute()
with open(STATS_DIR + "mean_by_level.nc", "rb") as f:
    mean_by_level = xarray.load_dataset(f).compute()
with open(STATS_DIR + "stddev_by_level.nc", "rb") as f:
    stddev_by_level = xarray.load_dataset(f).compute()
with open(STATS_DIR + "min_by_level.nc", "rb") as f:
    min_by_level = xarray.load_dataset(f).compute()

# GenCast構造を構築
def construct_wrapped_gencast():
    predictor = gencast.GenCast(
        sampler_config=sampler_config,
        task_config=task_config,
        denoiser_architecture_config=denoiser_architecture_config,
        noise_config=noise_config,
        noise_encoder_config=noise_encoder_config,
    )
    predictor = normalization.InputsAndResiduals(
        predictor,
        diffs_stddev_by_level=diffs_stddev_by_level,
        mean_by_level=mean_by_level,
        stddev_by_level=stddev_by_level,
    )
    predictor = nan_cleaning.NaNCleaner(
        predictor=predictor,
        reintroduce_nans=True,
        fill_value=min_by_level,
        var_to_clean="sea_surface_temperature",
    )
    return predictor

# トレーニングと評価データを抽出
train_inputs, train_targets, train_forcings = data_utils.extract_inputs_targets_forcings(
    example_batch, target_lead_times=slice("12h", "12h"), **dataclasses.asdict(task_config)
)

eval_inputs, eval_targets, eval_forcings = data_utils.extract_inputs_targets_forcings(
    # example_batch, target_lead_times=slice("12h", f"{(example_batch.dims['time'] - 2) * 12}h"), **dataclasses.asdict(task_config)
    example_batch, target_lead_times=slice("12h", f"{(example_batch.sizes['time'] - 2) * 12}h"), **dataclasses.asdict(task_config)
)

# メイン実行
if __name__ == "__main__":
    # プロットの例
    variable = "geopotential"
    level = 500
    # steps = example_batch.dims["time"]
    steps = example_batch.sizes["time"]
    data = {" ": scale(select(example_batch, variable, level, steps), robust=True)}
    ani = plot_data(data, variable, plot_size=7, robust=True)
    ani.save("example_animation.gif", writer="imagemagick")

実行方法と結果

スクリプトを実行するには以下のコマンドを使用します:

python3 gencast_demo_cloud_vm.py

実行すると、以下のようなモデル情報が表示されます:

Model description:
 
        GenCast model at lower, 1deg, resolution, with 13 pressure levels and a
        4 times refined icosahedral mesh. This model is trained on ERA5 data
        from 1979 to 2018, and can be causally evaluated on 2019 and later years.
        This model has the smallest memory footprint of those provided and has been provided
        to enable low cost demonstrations. It is not representative of GenCast's performance.
         

Model license:
 
The model weights are licensed under the Creative Commons
Attribution-NonCommercial-ShareAlike 4.0 International (CC BY-NC-SA 4.0). You
may obtain a copy of the License at:
https://creativecommons.org/licenses/by-nc-sa/4.0/.
The weights were trained on ERA5 data, see README for attribution statement.
 

Gif画像は保存されるはず

example_animation.gif

まとめ:Notebook vs スクリプト形式

Notebook形式には視覚的な魅力がありますが、長期的なプロジェクトやデバッグにやはりはスクリプト形式が適しています。今回、公式のNotebookをもとにPythonスクリプトを作成し、GenCastをローカルで動かすまでを解説しました。

次回の記事では、実際の天気データを用いた予測結果の可視化性能評価に焦点を当てる予定ですがどうなるのがわからないです。

最後まで読んでいただきありがとうございました!
ぜひGenCastを試してみてください。

2
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?