はじめに
AI気象モデルの推論には、重みのダウンロードや入力データの形式の変換が必要である。複数のモデルの比較する際に、個別のモデルごとに入出力データの形式を整えるのは手間であるため、ラッパーライブラリが存在する。
初期の頃はECMWFの作成したai-modelsというライブラリが使われていたが、現在はArchive扱いとなりこれ以上メンテナンスはされないようである。
最近活発に整備されているのがNVIDIAのEarth2Studioというライブラリである。本記事ではEarth2Studioを使ったAI気象モデルの推論についてまとめる。
ここではLinuxのCUDA環境を前提としている。ROCm環境での実行については以下の記事を参照のこと。
FourCastNet v2 (SFNO) を実行する
インストールは、適当な実行用のディレクトリを作成し移動したのち、uvを用いて次のような手順で行う。
# プロジェクトの初期化と仮想環境の作成
uv init --python=3.12
uv sync
# RTX 5090 (Blackwell世代) に対応したCUDA12.8のPyTorchを入れる
uv pip install torch torchvision --index-url https://download.pytorch.org/whl/cu128
# --extra でインストールしたいモデル名(sfno)を指定する
uv add "earth2studio @ git+https://github.com/NVIDIA/earth2studio.git@0.11.0" --extra sfno
# 色々な初期値データを使用するために必要
uv add earth2studio --extra data
sfno (Spherical Fourier Operator Network: SFNO) が今回インストールしたいFourCastNet v2(ai-modelsでいうfourcastnetv2)であるのに注意。fcnはFourCastNetのv1 (ai-modelsでいうfourcastnet)であり、長期間の予報(ロールアウト)によって場が崩壊する。
最小限の実行可能コードは次のとおり。
import os
from earth2studio.data import IFS
from earth2studio.io import NetCDF4Backend
from earth2studio.models.px import SFNO
import earth2studio.run as run
# SFNO(予報モデル)のロード
package = SFNO.load_default_package()
model = SFNO.load_model(package)
# データソースの作成(IFSデータ)
data = IFS()
# 出力先(NetCDF4バックエンド)
os.makedirs("outputs", exist_ok=True)
io = NetCDF4Backend("outputs/deterministic_result.nc", backend_kwargs={"mode": "w"})
# 予報開始時刻と予報ステップ数を指定
forecast_time = "2025-01-01"
nsteps = 40 # 6時間間隔 × 40ステップ = 240時間予報
run.deterministic([forecast_time], nsteps, model, data, io)
- 重みデータは初回にダウンロードされ、2回目以降の実行ではダウンロード済みのものが読み込まれる。重みデータのダウンロードはデフォルトでは300秒(5分)でタイムアウトするため、時間がかかる場合は
EARTH2STUDIO_PACKAGE_TIMEOUT=900等で設定を変更する。 - データソースとしては
IFSの他にGFSやARCO(ERA5)などが選択できる - 出力先(IO)の
NetCDF4Backendでは、逐次でファイルに書きこめる代わりに、圧縮やチャンクの設定をすることができない。圧縮したい場合は、一旦XarrayBackendでxarray.Datasetを経由して.to_netcdf()するとよい(後述)。ただし出力は一旦メモリ上に蓄積されるのに注意。
Precipitation AFNO v2 で降水を診断する
FourCastNet v2自体には降水が出力変数として用意されていないが、出力された気象場からPrecipitation AFNOというモデルを用いて診断することができる。この一連のパイプライン処理はrun.diagnosticで行うことができる。
Precipitaiton AFNO v2のインストールは次のとおり。
uv add earth2studio --extra precip-afno-v2
最小限の実行可能コードは次のとおり。
import os
from earth2studio.data import IFS
from earth2studio.io import NetCDF4Backend
from earth2studio.models.px import SFNO
from earth2studio.models.dx import PrecipitationAFNOv2
import earth2studio.run as run
# 出力ディレクトリの作成
os.makedirs("outputs", exist_ok=True)
# 予報モデル(SFNO)のロード
package = SFNO.load_default_package()
prognostic_model = SFNO.load_model(package)
# 診断モデル(PrecipitationAFNOv2)のロード
package = PrecipitationAFNOv2.load_default_package()
diagnostic_model = PrecipitationAFNOv2.load_model(package)
# データソースの作成(IFSデータ)
data = IFS()
# 出力先(NetCDF4バックエンド)
io = NetCDF4Backend("outputs/diagnostic_result.nc", backend_kwargs={"mode": "w"})
# 予報開始時刻と予報ステップ数を指定
forecast_time = "2025-01-01"
nsteps = 40 # 6時間間隔 × 40ステップ = 240時間予報
io = run.diagnostic([forecast_time], nsteps, prognostic_model, diagnostic_model, data, io)
tp06という変数名で6時間降水量が m 単位で出力される。また、緯度座標は720点で0.25°格子のうち南極点が存在しない点が、予報モデルのFourCastNetとは異なる。
run.diagnosticを使用する場合、元のprognostic_model (FourCastNet) の変数をNetCDFに出力することはできない。これを行いたい場合については次の節で扱う。
FourCastNet v2 + Precipitation AFNO v2 の結果をまとめて出力する
予報モデルと診断モデルの結果の両方を保存したい場合は、以下のrun.diagnosticのコードを参考に、カスタムワークフローを定義する。
ポイントはループの中で予報モデルと診断モデルの出力結果に対してio.write()を呼ぶことである。
with tqdm(total=nsteps + 1, desc="Running inference", position=1) as pbar:
for step, (x, coords) in enumerate(model):
# prognosticの出力を指定変数のみに絞り込んで書き込む
x_p, coords_p = map_coords(x, coords, output_coords)
io.write(*split_coords(x_p, coords_p))
# diagnosticの出力を書き込む
x_d, coords_d = map_coords(x, coords, diagnostic_ic)
x_d, coords_d = diagnostic_model(x_d, coords_d)
io.write(*split_coords(x_d, coords_d))
pbar.update(1)
if step == nsteps:
break
ソースコード全体は長いので折りたたんでおく。上記以外に、南極点の降水量の周囲の値での穴埋め、出力変数のフィルタや圧縮を行なっている。
ソースコード全体
import os
from collections import OrderedDict
import numpy as np
import torch
from tqdm import tqdm
from earth2studio.data import IFS, fetch_data
from earth2studio.io import XarrayBackend
from earth2studio.models.dx import PrecipitationAFNOv2
from earth2studio.models.px import SFNO
from earth2studio.utils.coords import map_coords, split_coords
from earth2studio.utils.time import to_time_array
def main():
# 出力ディレクトリの作成
os.makedirs("outputs", exist_ok=True)
# FourCastNet(予報モデル)のロード
print("FourCastNetモデルをロード中...")
fcn_package = SFNO.load_default_package()
prognostic_model = SFNO.load_model(fcn_package)
# PrecipitationAFNOv2(診断モデル)のロード
print("PrecipitationAFNOv2モデルをロード中...")
precip_package = PrecipitationAFNOv2.load_default_package()
diagnostic_model = PrecipitationAFNOv2.load_model(precip_package)
# データソースの作成(IFSデータ)
data = IFS()
# 出力先(Xarrayバックエンド)
io = XarrayBackend()
# 予報開始時刻と予報ステップ数を指定
forecast_time = "2025-01-01"
nsteps = 40 # 6時間間隔 × 40ステップ = 240時間予報
# デバイスの設定
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Inference device: {device}")
prognostic_model = prognostic_model.to(device)
diagnostic_model = diagnostic_model.to(device)
# 入力座標の取得
prognostic_ic = prognostic_model.input_coords()
diagnostic_ic = diagnostic_model.input_coords()
time_arr = to_time_array([forecast_time])
# 補間設定
if hasattr(prognostic_model, "interp_method"):
interp_to = prognostic_ic
interp_method = prognostic_model.interp_method
else:
interp_to = None
interp_method = "nearest"
# データ取得(1回だけ)
x0, coords0 = fetch_data(
source=data,
time=time_arr,
variable=prognostic_ic["variable"],
lead_time=prognostic_ic["lead_time"],
device=device,
interp_to=interp_to,
interp_method=interp_method,
)
# 出力変数の制限(必要な変数のみを出力)
output_coords = {
"variable": np.array(
[
"z500",
"u200",
"u500",
"u850",
"v200",
"v500",
"v850",
"t500",
"t700",
"t850",
"q500",
"q700",
"q850",
"msl",
"u10m",
"v10m",
"t2m",
"tcwv",
]
)
}
# IO座標のセットアップ(time/lead_timeは全ステップ分を用意)
prognostic_oc = OrderedDict(prognostic_model.output_coords(prognostic_ic)).copy()
diagnostic_oc = OrderedDict(diagnostic_model.output_coords(diagnostic_ic))
# output_coordsで指定された変数でprognostic_ocを上書き
for key, value in prognostic_oc.items():
prognostic_oc[key] = output_coords.get(key, value)
# バッチなどのダミー次元を除去(run.deterministic と同様)
for key, value in list(prognostic_oc.items()):
if value.shape == (0,):
del prognostic_oc[key]
# lead_time を全ステップ分に展開(run.deterministic と同様)
dt = prognostic_oc["lead_time"]
prognostic_oc["time"] = time_arr
prognostic_oc["lead_time"] = np.asarray(
[dt * i for i in range(nsteps + 1)]
).flatten()
prognostic_oc.move_to_end("lead_time", last=False)
prognostic_oc.move_to_end("time", last=False)
# 保存する変数名を union する(衝突した場合はFCN側を優先)
p_vars = list(prognostic_oc["variable"])
d_vars = list(diagnostic_oc["variable"])
var_names = p_vars + [v for v in d_vars if v not in p_vars]
coords_for_io = prognostic_oc.copy()
coords_for_io.pop("variable")
io.add_array(coords_for_io, var_names)
# Map lat/lon if needed
x0, coords0 = map_coords(x0, coords0, prognostic_ic)
# 予報を1回だけ回し、同じ状態からdiagnosticを計算して書き込む
model = prognostic_model.create_iterator(x0, coords0)
with tqdm(total=nsteps + 1, desc="Running inference", position=1) as pbar:
for step, (x, coords) in enumerate(model):
# prognosticの出力を指定変数のみに絞り込んで書き込む
x_p, coords_p = map_coords(x, coords, output_coords)
io.write(*split_coords(x_p, coords_p))
# diagnosticの出力を書き込む
x_d, coords_d = map_coords(x, coords, diagnostic_ic)
x_d, coords_d = diagnostic_model(x_d, coords_d)
io.write(*split_coords(x_d, coords_d))
pbar.update(1)
if step == nsteps:
break
# 南極点を周囲の値で穴埋めする
ds = io.root
ds["tp06"][:, :, -1, :] = ds["tp06"][:, :, -2, :].mean(dim="lon")
encoding_common = {"zlib": True, "complevel": 5, "shuffle": True, "dtype": "int16"}
g = 9.80665
encoding_z500 = {**encoding_common, "scale_factor": g, "add_offset": 5000 * g}
encoding_msl = {**encoding_common, "scale_factor": 10, "add_offset": 100000}
encoding_wind = {**encoding_common, "scale_factor": 0.1}
encoding_temp = {**encoding_common, "scale_factor": 0.1, "add_offset": 273.15}
encoding_q500 = {**encoding_common, "scale_factor": 5e-6}
encoding_q700 = {**encoding_common, "scale_factor": 1e-5}
encoding_q850 = {**encoding_common, "scale_factor": 2e-5}
encoding_tcwv = {**encoding_common, "scale_factor": 0.1}
encoding_tp06 = {**encoding_common, "scale_factor": 1e-4}
ds.isel(lead_time=slice(1, 41)).to_netcdf(
"outputs/result.nc",
encoding={
"z500": encoding_z500,
"u200": encoding_wind,
"u500": encoding_wind,
"u850": encoding_wind,
"v200": encoding_wind,
"v500": encoding_wind,
"v850": encoding_wind,
"t500": encoding_temp,
"t700": encoding_temp,
"t850": encoding_temp,
"q500": encoding_q500,
"q700": encoding_q700,
"q850": encoding_q850,
"msl": encoding_msl,
"u10m": encoding_wind,
"v10m": encoding_wind,
"t2m": encoding_temp,
"tcwv": encoding_tcwv,
"tp06": encoding_tp06,
},
)
ds.close()
if __name__ == "__main__":
main()
熱帯低気圧(TC)のトラッキングをする
熱帯低気圧(TC)のトラッキングをすることができる。なおCuPyに依存しているため、ROCm環境ではそのままでは動かない。
以下に上記コードからの差分を記述する。
まず追加パッケージのインポートを行う。Polarsは最後に表データに変換するために使っている。
import polars as pl
from earth2studio.models.dx import TCTrackerWuDuan, TCTrackerVitart
次にモデルをロードする。
# TCトラッカーのロード("wuduan" または "vitart" を選択)
tracker_type = "wuduan" # or "vitart"
print(f"TCトラッカー ({tracker_type}) をロード中...")
if tracker_type == "wuduan":
tracker = TCTrackerWuDuan().to(device)
else:
tracker = TCTrackerVitart().to(device)
推論ループ内で、トラッキングを実行する。tc_outputには過去のトラッキングの情報が蓄積されていく。
with tqdm(total=nsteps + 1, desc="Running inference", position=1) as pbar:
for step, (x, coords) in enumerate(model):
# prognosticの出力を指定変数のみに絞り込んで書き込む
x_p, coords_p = map_coords(x, coords, output_coords)
io.write(*split_coords(x_p, coords_p))
# diagnosticの出力を書き込む
x_d, coords_d = map_coords(x, coords, diagnostic_ic)
x_d, coords_d = diagnostic_model(x_d, coords_d)
io.write(*split_coords(x_d, coords_d))
# TCトラッキング実行
x_t, coords_t = map_coords(x, coords, tracker.input_coords())
tc_output, tc_coords = tracker(x_t, coords_t)
# lead_time次元を削除
tc_output = tc_output[:, 0]
pbar.update(1)
if step == nsteps:
break
ループの外側でtc_outputからトラッキングの結果を取り出す。元の出力はtorch.Tensorであるため、扱いやすいようにPolars経由でParquetに変換している。出力されるカラムはこの場合、batch, path_id, step, tc_lat, tc_lon, tc_msl, tc_w10mである。この部分はお好みで好きな形式にするとよいだろう。
# TCトラッキング結果をDataFrameに変換して保存
tc_data = tc_output.cpu().numpy()
tc_var_names = list(tc_coords["variable"])
# [batch, path, step, variable] -> フラットなDataFrameに変換
rows = []
for b in range(tc_data.shape[0]):
for p in range(tc_data.shape[1]):
for s in range(tc_data.shape[2]):
row = {"batch": b, "path_id": p, "step": s}
for v, var_name in enumerate(tc_var_names):
row[var_name] = tc_data[b, p, s, v]
rows.append(row)
df = pl.DataFrame(rows)
# NaN行を除去(トラッキングがない部分)
df = df.drop_nans(subset=tc_var_names)
df.write_parquet("outputs/tc_tracks.parquet")
print(
f"TCトラッキング結果を outputs/tc_tracks.parquet に保存しました({len(df)} rows)"
)
FuXiを実行する
FuXiは他のモデルと以下の点が異なるため、run.deterministicでそのまま推論することができない。
- 湿度が比湿ではなく相対湿度である
- 入力変数に降水量が含まれている
このうち1.については、model.dx.DerivedRHという気温と比湿から相対湿度を作成する診断モデルを利用することができる。
また2.については、値0を入力することで実行できる。これがFuXiでの正しい取り扱いであるかは不明だが、後継バージョンのFuXi2ではThe accumulated variables ('ssr', 'ssrd', 'fdir', 'ttr', 'tp') are not needed for input and can be set to zero. とあり、診断変数の初期値は0としている。
実際のコードは長いため以下に折りたたんで示す。
コード全体
import os
from collections import OrderedDict
from datetime import datetime
import numpy as np
import torch
from earth2studio.data import GFS, fetch_data
from earth2studio.io import XarrayBackend
from earth2studio.models.dx.derived import DerivedRH
from earth2studio.models.px import FuXi
from earth2studio.utils.coords import map_coords, split_coords
from earth2studio.utils.time import to_time_array
# FuXiの圧力レベル
FUXI_PRESSURE_LEVELS = [1000, 925, 850, 700, 600, 500, 400, 300, 250, 200, 150, 100, 50]
# 設定
forecast_time = "2025-01-01T00:00:00" # 予報開始時刻
nsteps = 4 # 予報ステップ数(6時間×4=24時間)
os.makedirs("output", exist_ok=True)
# デバイス設定
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")
# FuXiモデルのロード
print("FuXiモデルをロード中...")
fuxi_package = FuXi.load_default_package()
model = FuXi.load_model(fuxi_package, cascade_models=False).to(device)
# DerivedRH(比湿→相対湿度変換)
rh_model = DerivedRH(levels=FUXI_PRESSURE_LEVELS).to(device)
# GFSデータソース
data = GFS()
# FuXi入力座標を取得
input_coords = model.input_coords()
time_arr = to_time_array([forecast_time])
# 取得変数リスト作成(r* → q*、tp06は除外)
fetch_variables = []
for var in input_coords["variable"]:
if var.startswith("r") and var[1:].isdigit():
fetch_variables.append(f"q{var[1:]}")
elif var == "tp06":
continue # 降水量は後で0を設定
else:
fetch_variables.append(var)
# 補間設定
interp_to = OrderedDict({"_lat": input_coords["lat"], "_lon": input_coords["lon"]})
# データ取得
print("GFSデータを取得中...")
x_fetch, coords_fetch = fetch_data(
source=data,
time=time_arr,
variable=fetch_variables,
lead_time=input_coords["lead_time"],
device=device,
interp_to=interp_to,
interp_method="linear",
)
# 相対湿度を計算
print("相対湿度を計算中...")
rh_input_vars = []
for level in FUXI_PRESSURE_LEVELS:
rh_input_vars.extend([f"t{level}", f"q{level}"])
x_rh_in, coords_rh_in = map_coords(
x_fetch, coords_fetch, {"variable": np.array(rh_input_vars)}
)
x_rh, coords_rh = rh_model(x_rh_in, coords_rh_in)
# FuXi入力を構築
print("FuXi入力を構築中...")
var_dict = {
var: x_fetch[..., i : i + 1, :, :]
for i, var in enumerate(coords_fetch["variable"])
}
# 相対湿度を追加
for i, level in enumerate(FUXI_PRESSURE_LEVELS):
var_dict[f"r{level}"] = x_rh[..., i : i + 1, :, :]
# tp06に0を設定
tp06_shape = list(x_fetch.shape)
tp06_shape[-3] = 1
var_dict["tp06"] = torch.zeros(tp06_shape, dtype=x_fetch.dtype, device=device)
# 変数順に結合
x0 = torch.cat([var_dict[var] for var in input_coords["variable"]], dim=-3)
coords0 = OrderedDict(
{
"time": time_arr,
"lead_time": input_coords["lead_time"],
"variable": input_coords["variable"],
"lat": input_coords["lat"],
"lon": input_coords["lon"],
}
)
# 出力設定
io = XarrayBackend()
output_coords = OrderedDict(model.output_coords(input_coords)).copy()
for key, value in list(output_coords.items()):
if value.shape == (0,):
del output_coords[key]
dt = output_coords["lead_time"]
output_coords["time"] = time_arr
output_coords["lead_time"] = np.asarray(
[dt * i for i in range(nsteps + 1)]
).flatten()
output_coords.move_to_end("lead_time", last=False)
output_coords.move_to_end("time", last=False)
var_names = list(output_coords["variable"])
io_coords = output_coords.copy()
io_coords.pop("variable")
io.add_array(io_coords, var_names)
# 推論実行
print(f"推論実行中({nsteps}ステップ)...")
iterator = model.create_iterator(x0, coords0)
for step, (x, coords) in enumerate(iterator):
io.write(*split_coords(x, coords))
print(f" Step {step}/{nsteps} 完了")
if step == nsteps:
break
# 結果保存
output_file = "output/simple_fuxi_gfs.nc"
io.root.isel(lead_time=slice(1, nsteps + 1)).to_netcdf(output_file)
print(f"結果を {output_file} に保存しました。")
io.root.close()
AIFS-ENSを実行する
AIFS-ENSのようなアンサンブルモデルを実行する場合、run.ensembleワークフローを使用する。
import os
import numpy as np
from earth2studio.models.px import AIFSENS
from earth2studio.data import IFS
from earth2studio.io import NetCDF4Backend
from earth2studio.perturbation import Zero
from earth2studio.run import ensemble as run
# VRAM使用量の削減
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
os.environ["ANEMOI_INFERENCE_NUM_CHUNKS"] = "16"
model = AIFSENS.load_model(AIFSENS.load_default_package())
out_vars = ["u10m", "v10m", "t2m", "msl", "tcw"]
data = IFS()
io = NetCDF4Backend("aifs-ens.nc", backend_kwargs={"mode": "w"})
perturbation = Zero()
run(
time=["2025-01-01T00:00:00"],
nsteps=40,
nensemble=4,
prognostic=model,
data=data,
io=io,
perturbation=perturbation,
batch_size=1,
output_coords={"variable": np.array(out_vars)},
)
モデルの重みデータのキャッシュ先
モデルの重みデータは~/.cache/earth2studio/下にモデル別にダウンロードされキャッシュされる。ダウンロードに時間がかかるため、同じネットワーク内の別の計算機で推論する場合にはコピーして使うとよい。