0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

AI気象モデル推論ツール Earth2Studio の使い方

0
Last updated at Posted at 2025-12-31

はじめに

AI気象モデルの推論には、重みのダウンロードや入力データの形式の変換が必要である。複数のモデルの比較する際に、個別のモデルごとに入出力データの形式を整えるのは手間であるため、ラッパーライブラリが存在する。

初期の頃はECMWFの作成したai-modelsというライブラリが使われていたが、現在はArchive扱いとなりこれ以上メンテナンスはされないようである。

最近活発に整備されているのがNVIDIAのEarth2Studioというライブラリである。本記事ではEarth2Studioを使ったAI気象モデルの推論についてまとめる。

ここではLinuxのCUDA環境を前提としている。ROCm環境での実行については以下の記事を参照のこと。

FourCastNet v2 (SFNO) を実行する

インストールは、適当な実行用のディレクトリを作成し移動したのち、uvを用いて次のような手順で行う。

# プロジェクトの初期化と仮想環境の作成
uv init --python=3.12
uv sync

# RTX 5090 (Blackwell世代) に対応したCUDA12.8のPyTorchを入れる
uv pip install torch torchvision --index-url https://download.pytorch.org/whl/cu128 

# --extra でインストールしたいモデル名(sfno)を指定する
uv add "earth2studio @ git+https://github.com/NVIDIA/earth2studio.git@0.11.0" --extra sfno 

# 色々な初期値データを使用するために必要
uv add earth2studio --extra data

sfno (Spherical Fourier Operator Network: SFNO) が今回インストールしたいFourCastNet v2(ai-modelsでいうfourcastnetv2)であるのに注意。fcnはFourCastNetのv1 (ai-modelsでいうfourcastnet)であり、長期間の予報(ロールアウト)によって場が崩壊する。

最小限の実行可能コードは次のとおり。

import os

from earth2studio.data import IFS
from earth2studio.io import NetCDF4Backend
from earth2studio.models.px import SFNO
import earth2studio.run as run

# SFNO(予報モデル)のロード
package = SFNO.load_default_package()
model = SFNO.load_model(package)

# データソースの作成(IFSデータ)
data = IFS()

# 出力先(NetCDF4バックエンド)
os.makedirs("outputs", exist_ok=True)
io = NetCDF4Backend("outputs/deterministic_result.nc", backend_kwargs={"mode": "w"})

# 予報開始時刻と予報ステップ数を指定
forecast_time = "2025-01-01"
nsteps = 40  # 6時間間隔 × 40ステップ = 240時間予報

run.deterministic([forecast_time], nsteps, model, data, io)
  • 重みデータは初回にダウンロードされ、2回目以降の実行ではダウンロード済みのものが読み込まれる。重みデータのダウンロードはデフォルトでは300秒(5分)でタイムアウトするため、時間がかかる場合はEARTH2STUDIO_PACKAGE_TIMEOUT=900等で設定を変更する。
  • データソースとしてはIFSの他にGFSARCO(ERA5)などが選択できる
  • 出力先(IO)のNetCDF4Backendでは、逐次でファイルに書きこめる代わりに、圧縮やチャンクの設定をすることができない。圧縮したい場合は、一旦XarrayBackendでxarray.Datasetを経由して.to_netcdf()するとよい(後述)。ただし出力は一旦メモリ上に蓄積されるのに注意。

Precipitation AFNO v2 で降水を診断する

FourCastNet v2自体には降水が出力変数として用意されていないが、出力された気象場からPrecipitation AFNOというモデルを用いて診断することができる。この一連のパイプライン処理はrun.diagnosticで行うことができる。

Precipitaiton AFNO v2のインストールは次のとおり。

uv add earth2studio --extra precip-afno-v2

最小限の実行可能コードは次のとおり。

import os

from earth2studio.data import IFS
from earth2studio.io import NetCDF4Backend
from earth2studio.models.px import SFNO
from earth2studio.models.dx import PrecipitationAFNOv2
import earth2studio.run as run


# 出力ディレクトリの作成
os.makedirs("outputs", exist_ok=True)

# 予報モデル(SFNO)のロード
package = SFNO.load_default_package()
prognostic_model = SFNO.load_model(package)

# 診断モデル(PrecipitationAFNOv2)のロード
package = PrecipitationAFNOv2.load_default_package()
diagnostic_model = PrecipitationAFNOv2.load_model(package)

# データソースの作成(IFSデータ)
data = IFS()

# 出力先(NetCDF4バックエンド)
io = NetCDF4Backend("outputs/diagnostic_result.nc", backend_kwargs={"mode": "w"})

# 予報開始時刻と予報ステップ数を指定
forecast_time = "2025-01-01"
nsteps = 40  # 6時間間隔 × 40ステップ = 240時間予報

io = run.diagnostic([forecast_time], nsteps, prognostic_model, diagnostic_model, data, io)

tp06という変数名で6時間降水量が m 単位で出力される。また、緯度座標は720点で0.25°格子のうち南極点が存在しない点が、予報モデルのFourCastNetとは異なる。

run.diagnosticを使用する場合、元のprognostic_model (FourCastNet) の変数をNetCDFに出力することはできない。これを行いたい場合については次の節で扱う。

FourCastNet v2 + Precipitation AFNO v2 の結果をまとめて出力する

予報モデルと診断モデルの結果の両方を保存したい場合は、以下のrun.diagnosticのコードを参考に、カスタムワークフローを定義する。

ポイントはループの中で予報モデルと診断モデルの出力結果に対してio.write()を呼ぶことである。

with tqdm(total=nsteps + 1, desc="Running inference", position=1) as pbar:
    for step, (x, coords) in enumerate(model):
        # prognosticの出力を指定変数のみに絞り込んで書き込む
        x_p, coords_p = map_coords(x, coords, output_coords)
        io.write(*split_coords(x_p, coords_p))

        # diagnosticの出力を書き込む
        x_d, coords_d = map_coords(x, coords, diagnostic_ic)
        x_d, coords_d = diagnostic_model(x_d, coords_d)
        io.write(*split_coords(x_d, coords_d))

        pbar.update(1)
        if step == nsteps:
            break

ソースコード全体は長いので折りたたんでおく。上記以外に、南極点の降水量の周囲の値での穴埋め、出力変数のフィルタや圧縮を行なっている。

ソースコード全体
import os
from collections import OrderedDict

import numpy as np
import torch

from tqdm import tqdm

from earth2studio.data import IFS, fetch_data
from earth2studio.io import XarrayBackend
from earth2studio.models.dx import PrecipitationAFNOv2
from earth2studio.models.px import SFNO
from earth2studio.utils.coords import map_coords, split_coords
from earth2studio.utils.time import to_time_array


def main():
    # 出力ディレクトリの作成
    os.makedirs("outputs", exist_ok=True)

    # FourCastNet(予報モデル)のロード
    print("FourCastNetモデルをロード中...")
    fcn_package = SFNO.load_default_package()
    prognostic_model = SFNO.load_model(fcn_package)

    # PrecipitationAFNOv2(診断モデル)のロード
    print("PrecipitationAFNOv2モデルをロード中...")
    precip_package = PrecipitationAFNOv2.load_default_package()
    diagnostic_model = PrecipitationAFNOv2.load_model(precip_package)

    # データソースの作成(IFSデータ)
    data = IFS()

    # 出力先(Xarrayバックエンド)
    io = XarrayBackend()

    # 予報開始時刻と予報ステップ数を指定
    forecast_time = "2025-01-01"
    nsteps = 40  # 6時間間隔 × 40ステップ = 240時間予報

    # デバイスの設定
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Inference device: {device}")
    prognostic_model = prognostic_model.to(device)
    diagnostic_model = diagnostic_model.to(device)

    # 入力座標の取得
    prognostic_ic = prognostic_model.input_coords()
    diagnostic_ic = diagnostic_model.input_coords()
    time_arr = to_time_array([forecast_time])

    # 補間設定
    if hasattr(prognostic_model, "interp_method"):
        interp_to = prognostic_ic
        interp_method = prognostic_model.interp_method
    else:
        interp_to = None
        interp_method = "nearest"

    # データ取得(1回だけ)
    x0, coords0 = fetch_data(
        source=data,
        time=time_arr,
        variable=prognostic_ic["variable"],
        lead_time=prognostic_ic["lead_time"],
        device=device,
        interp_to=interp_to,
        interp_method=interp_method,
    )

    # 出力変数の制限(必要な変数のみを出力)
    output_coords = {
        "variable": np.array(
            [
                "z500",
                "u200",
                "u500",
                "u850",
                "v200",
                "v500",
                "v850",
                "t500",
                "t700",
                "t850",
                "q500",
                "q700",
                "q850",
                "msl",
                "u10m",
                "v10m",
                "t2m",
                "tcwv",
            ]
        )
    }

    # IO座標のセットアップ(time/lead_timeは全ステップ分を用意)
    prognostic_oc = OrderedDict(prognostic_model.output_coords(prognostic_ic)).copy()
    diagnostic_oc = OrderedDict(diagnostic_model.output_coords(diagnostic_ic))

    # output_coordsで指定された変数でprognostic_ocを上書き
    for key, value in prognostic_oc.items():
        prognostic_oc[key] = output_coords.get(key, value)

    # バッチなどのダミー次元を除去(run.deterministic と同様)
    for key, value in list(prognostic_oc.items()):
        if value.shape == (0,):
            del prognostic_oc[key]

    # lead_time を全ステップ分に展開(run.deterministic と同様)
    dt = prognostic_oc["lead_time"]
    prognostic_oc["time"] = time_arr
    prognostic_oc["lead_time"] = np.asarray(
        [dt * i for i in range(nsteps + 1)]
    ).flatten()
    prognostic_oc.move_to_end("lead_time", last=False)
    prognostic_oc.move_to_end("time", last=False)

    # 保存する変数名を union する(衝突した場合はFCN側を優先)
    p_vars = list(prognostic_oc["variable"])
    d_vars = list(diagnostic_oc["variable"])
    var_names = p_vars + [v for v in d_vars if v not in p_vars]

    coords_for_io = prognostic_oc.copy()
    coords_for_io.pop("variable")
    io.add_array(coords_for_io, var_names)

    # Map lat/lon if needed
    x0, coords0 = map_coords(x0, coords0, prognostic_ic)

    # 予報を1回だけ回し、同じ状態からdiagnosticを計算して書き込む
    model = prognostic_model.create_iterator(x0, coords0)
    with tqdm(total=nsteps + 1, desc="Running inference", position=1) as pbar:
        for step, (x, coords) in enumerate(model):
            # prognosticの出力を指定変数のみに絞り込んで書き込む
            x_p, coords_p = map_coords(x, coords, output_coords)
            io.write(*split_coords(x_p, coords_p))

            # diagnosticの出力を書き込む
            x_d, coords_d = map_coords(x, coords, diagnostic_ic)
            x_d, coords_d = diagnostic_model(x_d, coords_d)
            io.write(*split_coords(x_d, coords_d))

            pbar.update(1)
            if step == nsteps:
                break

    # 南極点を周囲の値で穴埋めする
    ds = io.root
    ds["tp06"][:, :, -1, :] = ds["tp06"][:, :, -2, :].mean(dim="lon")

    encoding_common = {"zlib": True, "complevel": 5, "shuffle": True, "dtype": "int16"}
    g = 9.80665
    encoding_z500 = {**encoding_common, "scale_factor": g, "add_offset": 5000 * g}
    encoding_msl = {**encoding_common, "scale_factor": 10, "add_offset": 100000}
    encoding_wind = {**encoding_common, "scale_factor": 0.1}
    encoding_temp = {**encoding_common, "scale_factor": 0.1, "add_offset": 273.15}
    encoding_q500 = {**encoding_common, "scale_factor": 5e-6}
    encoding_q700 = {**encoding_common, "scale_factor": 1e-5}
    encoding_q850 = {**encoding_common, "scale_factor": 2e-5}
    encoding_tcwv = {**encoding_common, "scale_factor": 0.1}
    encoding_tp06 = {**encoding_common, "scale_factor": 1e-4}
    ds.isel(lead_time=slice(1, 41)).to_netcdf(
        "outputs/result.nc",
        encoding={
            "z500": encoding_z500,
            "u200": encoding_wind,
            "u500": encoding_wind,
            "u850": encoding_wind,
            "v200": encoding_wind,
            "v500": encoding_wind,
            "v850": encoding_wind,
            "t500": encoding_temp,
            "t700": encoding_temp,
            "t850": encoding_temp,
            "q500": encoding_q500,
            "q700": encoding_q700,
            "q850": encoding_q850,
            "msl": encoding_msl,
            "u10m": encoding_wind,
            "v10m": encoding_wind,
            "t2m": encoding_temp,
            "tcwv": encoding_tcwv,
            "tp06": encoding_tp06,
        },
    )

    ds.close()


if __name__ == "__main__":
    main()

熱帯低気圧(TC)のトラッキングをする

熱帯低気圧(TC)のトラッキングをすることができる。なおCuPyに依存しているため、ROCm環境ではそのままでは動かない。

以下に上記コードからの差分を記述する。

まず追加パッケージのインポートを行う。Polarsは最後に表データに変換するために使っている。

import polars as pl
from earth2studio.models.dx import TCTrackerWuDuan, TCTrackerVitart

次にモデルをロードする。

# TCトラッカーのロード("wuduan" または "vitart" を選択)
tracker_type = "wuduan"  # or "vitart"
print(f"TCトラッカー ({tracker_type}) をロード中...")
if tracker_type == "wuduan":
    tracker = TCTrackerWuDuan().to(device)
else:
    tracker = TCTrackerVitart().to(device)

推論ループ内で、トラッキングを実行する。tc_outputには過去のトラッキングの情報が蓄積されていく。

with tqdm(total=nsteps + 1, desc="Running inference", position=1) as pbar:
    for step, (x, coords) in enumerate(model):
        # prognosticの出力を指定変数のみに絞り込んで書き込む
        x_p, coords_p = map_coords(x, coords, output_coords)
        io.write(*split_coords(x_p, coords_p))

        # diagnosticの出力を書き込む
        x_d, coords_d = map_coords(x, coords, diagnostic_ic)
        x_d, coords_d = diagnostic_model(x_d, coords_d)
        io.write(*split_coords(x_d, coords_d))

        # TCトラッキング実行
        x_t, coords_t = map_coords(x, coords, tracker.input_coords())
        tc_output, tc_coords = tracker(x_t, coords_t)
        # lead_time次元を削除
        tc_output = tc_output[:, 0]

        pbar.update(1)
        if step == nsteps:
            break

ループの外側でtc_outputからトラッキングの結果を取り出す。元の出力はtorch.Tensorであるため、扱いやすいようにPolars経由でParquetに変換している。出力されるカラムはこの場合、batch, path_id, step, tc_lat, tc_lon, tc_msl, tc_w10mである。この部分はお好みで好きな形式にするとよいだろう。

# TCトラッキング結果をDataFrameに変換して保存
tc_data = tc_output.cpu().numpy()
tc_var_names = list(tc_coords["variable"])

# [batch, path, step, variable] -> フラットなDataFrameに変換
rows = []
for b in range(tc_data.shape[0]):
    for p in range(tc_data.shape[1]):
        for s in range(tc_data.shape[2]):
            row = {"batch": b, "path_id": p, "step": s}
            for v, var_name in enumerate(tc_var_names):
                row[var_name] = tc_data[b, p, s, v]
            rows.append(row)

df = pl.DataFrame(rows)
# NaN行を除去(トラッキングがない部分)
df = df.drop_nans(subset=tc_var_names)
df.write_parquet("outputs/tc_tracks.parquet")
print(
    f"TCトラッキング結果を outputs/tc_tracks.parquet に保存しました({len(df)} rows)"
)

FuXiを実行する

FuXiは他のモデルと以下の点が異なるため、run.deterministicでそのまま推論することができない。

  1. 湿度が比湿ではなく相対湿度である
  2. 入力変数に降水量が含まれている

このうち1.については、model.dx.DerivedRHという気温と比湿から相対湿度を作成する診断モデルを利用することができる。

また2.については、値0を入力することで実行できる。これがFuXiでの正しい取り扱いであるかは不明だが、後継バージョンのFuXi2ではThe accumulated variables ('ssr', 'ssrd', 'fdir', 'ttr', 'tp') are not needed for input and can be set to zero. とあり、診断変数の初期値は0としている。

実際のコードは長いため以下に折りたたんで示す。

コード全体
import os
from collections import OrderedDict
from datetime import datetime

import numpy as np
import torch

from earth2studio.data import GFS, fetch_data
from earth2studio.io import XarrayBackend
from earth2studio.models.dx.derived import DerivedRH
from earth2studio.models.px import FuXi
from earth2studio.utils.coords import map_coords, split_coords
from earth2studio.utils.time import to_time_array

# FuXiの圧力レベル
FUXI_PRESSURE_LEVELS = [1000, 925, 850, 700, 600, 500, 400, 300, 250, 200, 150, 100, 50]

# 設定
forecast_time = "2025-01-01T00:00:00"  # 予報開始時刻
nsteps = 4  # 予報ステップ数(6時間×4=24時間)
os.makedirs("output", exist_ok=True)

# デバイス設定
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

# FuXiモデルのロード
print("FuXiモデルをロード中...")
fuxi_package = FuXi.load_default_package()
model = FuXi.load_model(fuxi_package, cascade_models=False).to(device)

# DerivedRH(比湿→相対湿度変換)
rh_model = DerivedRH(levels=FUXI_PRESSURE_LEVELS).to(device)

# GFSデータソース
data = GFS()

# FuXi入力座標を取得
input_coords = model.input_coords()
time_arr = to_time_array([forecast_time])

# 取得変数リスト作成(r* → q*、tp06は除外)
fetch_variables = []
for var in input_coords["variable"]:
    if var.startswith("r") and var[1:].isdigit():
        fetch_variables.append(f"q{var[1:]}")
    elif var == "tp06":
        continue  # 降水量は後で0を設定
    else:
        fetch_variables.append(var)

# 補間設定
interp_to = OrderedDict({"_lat": input_coords["lat"], "_lon": input_coords["lon"]})

# データ取得
print("GFSデータを取得中...")
x_fetch, coords_fetch = fetch_data(
    source=data,
    time=time_arr,
    variable=fetch_variables,
    lead_time=input_coords["lead_time"],
    device=device,
    interp_to=interp_to,
    interp_method="linear",
)

# 相対湿度を計算
print("相対湿度を計算中...")
rh_input_vars = []
for level in FUXI_PRESSURE_LEVELS:
    rh_input_vars.extend([f"t{level}", f"q{level}"])

x_rh_in, coords_rh_in = map_coords(
    x_fetch, coords_fetch, {"variable": np.array(rh_input_vars)}
)
x_rh, coords_rh = rh_model(x_rh_in, coords_rh_in)

# FuXi入力を構築
print("FuXi入力を構築中...")
var_dict = {
    var: x_fetch[..., i : i + 1, :, :]
    for i, var in enumerate(coords_fetch["variable"])
}

# 相対湿度を追加
for i, level in enumerate(FUXI_PRESSURE_LEVELS):
    var_dict[f"r{level}"] = x_rh[..., i : i + 1, :, :]

# tp06に0を設定
tp06_shape = list(x_fetch.shape)
tp06_shape[-3] = 1
var_dict["tp06"] = torch.zeros(tp06_shape, dtype=x_fetch.dtype, device=device)

# 変数順に結合
x0 = torch.cat([var_dict[var] for var in input_coords["variable"]], dim=-3)
coords0 = OrderedDict(
    {
        "time": time_arr,
        "lead_time": input_coords["lead_time"],
        "variable": input_coords["variable"],
        "lat": input_coords["lat"],
        "lon": input_coords["lon"],
    }
)

# 出力設定
io = XarrayBackend()
output_coords = OrderedDict(model.output_coords(input_coords)).copy()
for key, value in list(output_coords.items()):
    if value.shape == (0,):
        del output_coords[key]

dt = output_coords["lead_time"]
output_coords["time"] = time_arr
output_coords["lead_time"] = np.asarray(
    [dt * i for i in range(nsteps + 1)]
).flatten()
output_coords.move_to_end("lead_time", last=False)
output_coords.move_to_end("time", last=False)

var_names = list(output_coords["variable"])
io_coords = output_coords.copy()
io_coords.pop("variable")
io.add_array(io_coords, var_names)

# 推論実行
print(f"推論実行中({nsteps}ステップ)...")
iterator = model.create_iterator(x0, coords0)
for step, (x, coords) in enumerate(iterator):
    io.write(*split_coords(x, coords))
    print(f"  Step {step}/{nsteps} 完了")
    if step == nsteps:
        break

# 結果保存
output_file = "output/simple_fuxi_gfs.nc"
io.root.isel(lead_time=slice(1, nsteps + 1)).to_netcdf(output_file)
print(f"結果を {output_file} に保存しました。")
io.root.close()

AIFS-ENSを実行する

AIFS-ENSのようなアンサンブルモデルを実行する場合、run.ensembleワークフローを使用する。

import os
import numpy as np
from earth2studio.models.px import AIFSENS
from earth2studio.data import IFS
from earth2studio.io import NetCDF4Backend
from earth2studio.perturbation import Zero
from earth2studio.run import ensemble as run


# VRAM使用量の削減
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
os.environ["ANEMOI_INFERENCE_NUM_CHUNKS"] = "16"

model = AIFSENS.load_model(AIFSENS.load_default_package())
out_vars = ["u10m", "v10m", "t2m", "msl", "tcw"]
data = IFS()
io = NetCDF4Backend("aifs-ens.nc", backend_kwargs={"mode": "w"})
perturbation = Zero()
run(
    time=["2025-01-01T00:00:00"],
    nsteps=40,
    nensemble=4,
    prognostic=model,
    data=data,
    io=io,
    perturbation=perturbation,
    batch_size=1,
    output_coords={"variable": np.array(out_vars)},
)

モデルの重みデータのキャッシュ先

モデルの重みデータは~/.cache/earth2studio/下にモデル別にダウンロードされキャッシュされる。ダウンロードに時間がかかるため、同じネットワーク内の別の計算機で推論する場合にはコピーして使うとよい。

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?