PINN(Physics-Informed Neural Network)は「少ないデータでも、物理法則(微分方程式)を使って学習を補助する」枠組みです。
さらに 未知パラメータ(例:速度定数 k)も同時に推定する逆問題にすると、研究現場での使い所が一気に増えます。
そして研究者が次に欲しくなるのがこれ:
- 推定した k_hat は、どれくらい確からしいのか?(不確実性)
- 次に測るなら、どの時刻 t が最も情報を増やせるか?(実験計画)
この記事では、厳密なベイズ推論(Bayesian PINN)に飛び込む前に、実務で回しやすい入口として
- Deep Ensemble(seed違いで複数回学習)=“Bayesian PINN風”
- 感度(du/dk)から Fisher情報っぽい指標で「次に測るべき t」を提案
という最小セットを Google Colab で再現します。
注意:ここでの「Bayesian PINN風」は “厳密なベイズ事後分布”ではありません。
Deep Ensembleのばらつきには「学習の不安定性・局所解・データ不足」などが混ざります。
それでも「次に測るべき点を決める」「不確実性の気配を見る」用途では、現場で役立つことがあります。
TL;DR
- 逆問題PINNで u(t) と未知パラメータ k を同時推定できる
- Deep Ensembleで k_hat のばらつきを見て“Bayesian-ish”な不確実性の手がかりを得る
- 「次に測るべき t」は k同定に効く感度(du/dk)から提案できる
(指数減衰では 概ね t ≈ 1/k 付近が効きやすい)
1. 用語の整理(初心者向け)
1.1 PINNとは?
ニューラルネットで未知関数 u(t) を近似しつつ、
「データに合う」だけでなく「微分方程式も満たす」ように学習する方法です。
代表的な損失は以下です:
- データ損失:観測値に合う(例:MSE)
- 物理損失(残差):微分方程式の残差がゼロに近い
-
初期条件/境界条件損失:
u(0)=1などを守る
これらを足した損失を最小化します。
1.2 逆問題PINNとは?
微分方程式の パラメータ(ここでは k)も未知として、
u(t) と k を同時に推定する問題です。
1.3 Bayesian PINN“風”とは?
厳密なベイズ推論をせず、まずは簡単に不確実性の気配を出すアイデアです。
- 同じデータで seedだけ変えて複数回学習
- 得られた
k_hatの分布(ばらつき)を 不確実性の代理として扱う
これを Deep Ensemble と呼びます。
2. 今回のトイ問題(指数減衰 + 未知k)
最もシンプルな逆問題として指数減衰を扱います。
-
真のモデル(トイなので真値がある)
u(t) = exp(-k t) -
微分方程式(物理法則)
u'(t) + k u(t) = 0 -
初期条件
u(0) = 1
観測は ノイズ付きで少数点のみ、という状況を作ります。
3. Colabコード(動作版:最初から全部)
3.1 セル1:セットアップ&ダミーデータ作成
# === Setup ===
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
# 重要:dtypeを統一(float32推奨)
DTYPE = torch.float32
torch.set_default_dtype(DTYPE)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device:", device)
os.makedirs("fig", exist_ok=True)
# === Toy data (exponential decay with unknown k) ===
np.random.seed(0)
torch.manual_seed(0)
k_true = 1.5 # 真のk(トイなので知っている)
t_max = 2.0
sigma_obs = 0.03
n_obs = 12
# seed=0 の例だと、たまたま観測時刻が 0.75〜2.0 に寄り、
# 「初期のデータが無い」状況になりやすい(それが今回の学びポイントにもなる)
t_obs = np.sort(np.random.rand(n_obs).astype(np.float32) * t_max)
u_true_obs = np.exp(-k_true * t_obs).astype(np.float32)
u_obs = (u_true_obs + np.random.randn(n_obs).astype(np.float32) * sigma_obs).astype(np.float32)
# 可視化用の真の曲線
t_grid = np.linspace(0, t_max, 401, dtype=np.float32)
u_true_grid = np.exp(-k_true * t_grid).astype(np.float32)
plt.figure(figsize=(6.8,4.2))
plt.plot(t_grid, u_true_grid, label="ground truth (toy)")
plt.scatter(t_obs, u_obs, s=70, label="noisy observations")
plt.xlabel("t")
plt.ylabel("u(t)")
plt.title("Toy data: exponential decay with unknown k")
plt.legend()
plt.tight_layout()
plt.savefig("fig/fig1_toy_data.png", dpi=160)
plt.show()
print("Saved: fig/fig1_toy_data.png")
print("t_obs min/max:", t_obs.min(), t_obs.max())
3.2 セル2:PINN(逆問題)のモデル&学習関数
# === Model ===
class MLP(nn.Module):
def __init__(self, width=32, depth=3):
super().__init__()
layers = []
in_dim = 1
for _ in range(depth):
layers.append(nn.Linear(in_dim, width))
layers.append(nn.Tanh())
in_dim = width
layers.append(nn.Linear(in_dim, 1))
self.net = nn.Sequential(*layers)
def forward(self, t):
return self.net(t)
# === Inverse PINN trainer ===
def train_inverse_pinn(
t_obs_np: np.ndarray,
u_obs_np: np.ndarray,
t_max: float,
seed: int = 0,
steps: int = 1500,
lr: float = 1e-3,
n_phys: int = 64,
w_data: float = 1.0,
w_phys: float = 1.0,
w_ic: float = 10.0,
width: int = 32,
depth: int = 3,
verbose_every: int = 500,
):
"""
Solve inverse problem:
u'(t) + k u(t) = 0, u(0)=1
with unknown k (estimated simultaneously).
"""
np.random.seed(seed)
torch.manual_seed(seed)
model = MLP(width=width, depth=depth).to(device)
# k>0 を保証するため log_k を学習
log_k = nn.Parameter(torch.tensor([0.0], dtype=DTYPE, device=device))
opt = torch.optim.Adam(list(model.parameters()) + [log_k], lr=lr)
# data tensors (float32!)
t_data = torch.tensor(t_obs_np.reshape(-1,1), dtype=DTYPE, device=device)
u_data = torch.tensor(u_obs_np.reshape(-1,1), dtype=DTYPE, device=device)
# initial condition
t0 = torch.tensor([[0.0]], dtype=DTYPE, device=device)
u0 = torch.tensor([[1.0]], dtype=DTYPE, device=device)
hist = {"step": [], "total": [], "data": [], "phys": [], "ic": [], "k": []}
for step in range(steps):
opt.zero_grad()
k = torch.exp(log_k) # positive
# data loss
u_pred = model(t_data)
loss_data = torch.mean((u_pred - u_data)**2)
# physics loss (collocation points)
t_phys = (torch.rand(n_phys, 1, dtype=DTYPE, device=device) * float(t_max)).requires_grad_(True)
u_phys = model(t_phys)
du_dt = torch.autograd.grad(
u_phys, t_phys,
grad_outputs=torch.ones_like(u_phys),
create_graph=True
)[0]
r = du_dt + k * u_phys
loss_phys = torch.mean(r**2)
# initial condition loss
u0_pred = model(t0)
loss_ic = torch.mean((u0_pred - u0)**2)
loss = w_data*loss_data + w_phys*loss_phys + w_ic*loss_ic
loss.backward()
opt.step()
if (step % verbose_every == 0) or (step == steps - 1):
hist["step"].append(step)
hist["total"].append(loss.item())
hist["data"].append(loss_data.item())
hist["phys"].append(loss_phys.item())
hist["ic"].append(loss_ic.item())
hist["k"].append(k.item())
print(f"[seed={seed}] step={step:4d} total={loss.item():.6f} "
f"(data={loss_data.item():.6f}, phys={loss_phys.item():.6f}, ic={loss_ic.item():.6f}) "
f"k_hat={k.item():.4f}")
# prediction grid
t_grid = np.linspace(0, t_max, 401, dtype=np.float32)
with torch.no_grad():
u_grid = model(torch.tensor(t_grid.reshape(-1,1), dtype=DTYPE, device=device)) \
.cpu().numpy().reshape(-1).astype(np.float32)
k_hat = float(torch.exp(log_k).detach().cpu().numpy())
return k_hat, t_grid, u_grid, hist
3.3 セル3:まず1回だけ学習して動作確認(kの収束と再構成)
k_hat, t_grid, u_hat_grid, hist = train_inverse_pinn(
t_obs, u_obs, t_max=t_max,
seed=0,
steps=1500,
n_phys=64,
width=32, depth=3,
verbose_every=500
)
print("\nTrue k =", k_true)
print("Estimated k_hat =", k_hat)
# (1) kの推移
plt.figure(figsize=(6.8,4.2))
plt.plot(hist["step"], hist["k"], marker="o")
plt.axhline(k_true, linestyle="--", label="k_true (toy)")
plt.xlabel("step")
plt.ylabel("k_hat")
plt.title("k_hat convergence (single run)")
plt.legend()
plt.tight_layout()
plt.savefig("fig/fig2_k_convergence_single.png", dpi=160)
plt.show()
# (2) 再構成(単発)
plt.figure(figsize=(6.8,4.2))
plt.plot(t_grid, u_true_grid, label="ground truth (toy)")
plt.scatter(t_obs, u_obs, s=70, label="noisy obs")
plt.plot(t_grid, u_hat_grid, label=f"PINN inverse (k_hat={k_hat:.3f})")
plt.xlabel("t")
plt.ylabel("u(t)")
plt.title("Reconstruction (single inverse PINN)")
plt.legend()
plt.tight_layout()
plt.savefig("fig/fig3_reconstruction_single.png", dpi=160)
plt.show()
print("Saved: fig/fig2_k_convergence_single.png")
print("Saved: fig/fig3_reconstruction_single.png")
3.4 セル4:Deep Ensemble(Bayesian PINN“風”)で k_hat の不確実性を出す
n_ens = 10
steps = 1200
k_hats = []
u_grids = []
for i in range(n_ens):
seed = 100 + i
k_i, t_grid, u_i, _ = train_inverse_pinn(
t_obs, u_obs, t_max=t_max,
seed=seed,
steps=steps,
n_phys=64,
width=32, depth=3,
verbose_every=steps # 最後だけ表示
)
k_hats.append(k_i)
u_grids.append(u_i)
# 重要:listのまま演算しない
k_hats = np.asarray(k_hats, dtype=np.float32)
u_grids = np.asarray(u_grids, dtype=np.float32) # (n_ens, len(t_grid))
k_mean = float(k_hats.mean())
k_std = float(k_hats.std(ddof=1))
print("\n=== Ensemble summary ===")
print("True k =", k_true)
print(f"k_hat mean ± std = {k_mean:.4f} ± {k_std:.4f}")
# k_hat ヒストグラム
plt.figure(figsize=(6.8,4.2))
plt.hist(k_hats, bins=10, alpha=0.85)
plt.axvline(k_true, linestyle="--", label="k_true (toy)")
plt.xlabel("k_hat")
plt.ylabel("count")
plt.title("Deep ensemble: k_hat distribution (Bayesian-ish)")
plt.legend()
plt.tight_layout()
plt.savefig("fig/fig4_k_hist.png", dpi=160)
plt.show()
# u(t) の予測平均と不確実性(±2σ)
u_mean = u_grids.mean(axis=0)
u_std = u_grids.std(axis=0, ddof=1)
plt.figure(figsize=(6.8,4.2))
plt.plot(t_grid, u_true_grid, label="ground truth (toy)")
plt.scatter(t_obs, u_obs, s=70, label="noisy obs")
plt.plot(t_grid, u_mean, label="ensemble mean")
plt.fill_between(t_grid, u_mean-2*u_std, u_mean+2*u_std, alpha=0.25, label="±2σ (ensemble)")
plt.xlabel("t")
plt.ylabel("u(t)")
plt.title("Bayesian-ish uncertainty band (deep ensemble)")
plt.legend()
plt.tight_layout()
plt.savefig("fig/fig5_ensemble_band.png", dpi=160)
plt.show()
print("Saved: fig/fig4_k_hist.png")
print("Saved: fig/fig5_ensemble_band.png")
3.5 セル5:「次に測るべき t」を提案(k同定の情報量っぽい指標)
指数減衰 u(t)=exp(-k t) では、kに対する感度は
du/dk = -t exp(-k t) です。
観測ノイズ分散を sigma_obs^2 とすると、Fisher情報っぽい量は(比例で)
I(t) ∝ (du/dk)^2 / sigma_obs^2 = t^2 exp(-2 k t) / sigma_obs^2
kが不確かなときは、アンサンブルで平均して「期待情報量」っぽくします。
t_cand = t_grid.copy() # float32
sigma2 = float(sigma_obs**2)
# I(t) ~ E_k[t^2 * exp(-2 k t)] / sigma^2
info = ((t_cand[None,:]**2) * np.exp(-2.0 * k_hats[:,None] * t_cand[None,:])).mean(axis=0) / sigma2
# 既に測った点に近すぎるtは避ける(「新しいtを測る」前提)
delta = 0.03
mask = np.ones_like(t_cand, dtype=bool)
for t in t_obs:
mask &= (np.abs(t_cand - t) > delta)
t_next = float(t_cand[mask][np.argmax(info[mask])])
print("Suggested next t (for identifying k) =", t_next)
# plot info curve
plt.figure(figsize=(6.8,4.2))
plt.plot(t_cand, info, label="info score (Fisher-like)")
for t in t_obs:
plt.axvline(t, alpha=0.15)
plt.axvline(t_next, linestyle="--", label=f"suggested t={t_next:.3f}")
plt.xlabel("t")
plt.ylabel("info score (relative)")
plt.title("Where to measure next? (simple criterion for k)")
plt.legend()
plt.tight_layout()
plt.savefig("fig/fig6_next_t_info.png", dpi=160)
plt.show()
# u(t)の“不確実性が大きい場所”(予測の揺れ)
plt.figure(figsize=(6.8,4.2))
plt.plot(t_grid, u_std, label="u(t) ensemble std")
for t in t_obs:
plt.axvline(t, alpha=0.15)
plt.axvline(t_next, linestyle="--", label=f"suggested t={t_next:.3f}")
plt.xlabel("t")
plt.ylabel("std (ensemble)")
plt.title("Uncertainty over time (ensemble std)")
plt.legend()
plt.tight_layout()
plt.savefig("fig/fig7_uncertainty_over_t.png", dpi=160)
plt.show()
print("Saved: fig/fig6_next_t_info.png")
print("Saved: fig/fig7_uncertainty_over_t.png")
4. 実行して得られる図
この実行では、以下の図が得られます。
5. 結果の分析と考察(今回の図をどう読むか)
5.1 観測時刻が「初期を外している」のが効いている(図1)
図1を見ると、観測点が t≈0.75〜2.0 に集中していて、t<0.7 にデータがありません。
指数減衰で k を同定するとき、直感的には
- t が小さすぎる:まだ減衰しておらず、kの影響が見えにくい
- t が大きすぎる:u(t) が小さくなり、観測ノイズに埋もれやすい
ので、「ちょうど良い中間」が欲しくなります。
この“中間の重要性”が、後の t≈0.8 提案に繋がります。
5.2 単発の逆PINNは「データには合う」が「kはズレうる」(図2・図3)
図2・図3より、単発の逆PINNは
- 観測点(ノイズ付き)にはよくフィットする
- しかし推定された k_hat は真値 k_true より 小さめになりやすい
→ 図3の予測曲線が真値より 少し上側(減衰が遅い) になる
という挙動でした。
ここが逆問題の怖いところで、
「観測に合う」≠「パラメータが正しい」
が普通に起きます。
特に、観測が遅い時間帯に偏ると(今回の図1)、kの同定が難しくなります。
5.3 Deep Ensembleは「ばらつき」だけでなく「偏り」も見せてくれる(図4)
図4(k_hatヒストグラム)では、k_hat が ある範囲に分布します。
ここで重要なのは2点です。
(A) ばらつき(=Bayesian-ish不確実性の手がかり)
同じデータでも seed を変えると k_hat が揺れます。
この揺れは「不確実性の気配」を与えます。
(B) 今回の実行では “真値より小さめに偏っている”
図4の分布が k_true=1.5 より低い側に寄っていました。
これは「単に不確実」だけでなく、
- データ配置がk同定に不利(図1)
- モデル・損失・最適化の条件で一方向に引っ張られる
- そもそもこのデータでは k が十分同定できない(識別性が弱い)
などの可能性を示します。
つまり今回の図4は
「ばらつきが小さい=信頼できる」ではない
という教訓にもなります(“自信満々に外す”があり得る)。
5.4 u(t) の ±2σ バンドが「真値を覆わない」場合の解釈(図5)
図5では ensemble mean ±2σ を描いていますが、今回の実行では
- 平均が真値から少しズレる(k_hat が小さい)
- ±2σ もそれほど大きくない
ため、真値曲線がバンドに入らない領域がありました。
これは「誤り」ではなく、Deep Ensembleの性質として自然です:
- Deep Ensembleは “学習の揺れ” を拾う
- モデルの系統誤差(バイアス) は別問題
したがってこの記事では、±2σ を “信用区間”だと思わないことを明確にします。
5.5 「次に測るべき t」=k同定に効く時刻(図6)
図6の info score は、指数減衰の感度 du/dk = -t exp(-k t) から作っています。
-
I(t) ∝ t^2 exp(-2 k t)は、t=1/k で最大になります - 今回の実行では t≈0.8 が提案されました
これは、アンサンブルの k_hat が 1.2〜1.4 あたりに集まる
→ 1/k が 0.7〜0.8 付近
という状況と整合します。
5.6 「u(t)の不確実性が大きい場所」と「k同定に効くt」は一致しない(図7)
図7では u(t) の ensemble std が t≈0.4 付近で最大でした。
一方で、k同定の提案 t は 0.8 でした。
ここは誤解されがちなので、言い切っておきます:
- u(t) を正確に再構成したい → “uの不確実性が大きい”場所を測るのが合理的
- k を正確に同定したい → “k感度が最大”の場所を測るのが合理的(図6)
目的が違うので、提案点が違って当然です。
研究では「何を改善したいのか(uかkか、あるいは両方か)」をまず明確にするのが重要です。
6. この記事を結果に合わせて“安全運転”にするための注意(重要)
今回の実行結果(図4・図5の“真値より低めに偏る”)を踏まえ、読者が誤解しないように、次を明記します。
- Deep Ensembleの分布は ベイズ事後分布ではない
- 分布が真値を外すなら、「不確実」以前に 識別性やバイアスを疑う
- その場合のアクションは
- 測定点の追加(図6で提案された t など)
- 観測ノイズの見直し(繰り返し測定・SNR改善)
- 学習設定の見直し(重み、collocation、最適化手順)
-
より確率的な方法(MC Dropout / Laplace / SWAG / VI 等)
のどれか(あるいは複合)
7. 次の一歩(研究で“使える”形にする拡張)
7.1 追加測定→再学習→再提案(Active Learningループ)
実務での王道はこれです。
- いまのデータで逆PINNを回す
- 不確実性やinfoで次点を提案する(図6)
- 追加測定する
- 再学習して k_hat 分布がどう変わるか見る
- 収束するまで繰り返す
7.2 目的が「k」なのか「曲線再構成」なのかを分ける
図6と図7がズレたのは“バグ”ではなく仕様です。
どちらの目的を最適化するかで設計指標を選びます。
まとめ
- 逆問題PINNで u(t)とkを同時推定できる
- Deep Ensembleで 不確実性の気配を出せるが、偏り(バイアス)も起きる
- 「次に測るべき t」は、目的(k同定 or 曲線再構成)に合わせて設計する
- 今回の実行結果は、観測が遅い時刻に偏ると k 推定が偏りやすいこと、
そして “次に測るt”の提案が重要になることを分かりやすく示している
この記事の狙い
- “厳密なベイズ”に行く前に、研究現場で回せる最小セットを手元で動かす
- 実験計画(次に測る点)まで含めて、モデルとデータ収集を同じ土俵で議論する













