1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

導入:PINNで「状態」と「未知パラメータ」を同時に推定する(逆問題)

物理・化学・生体のモデルでは、微分方程式の形は分かっていても パラメータ(反応速度定数、減衰率、拡散係数…)が未知 という状況がよくあります。

たとえば、

  • 生理学:代謝・薬物動態の消失速度定数
  • 生物物理:緩和・回復の時定数
  • 医工学:装置や回路の定数、材料の係数

こうした「モデルはあるが定数が分からない」問題は、典型的な 逆問題 です。

本記事では、PINN(Physics-Informed Neural Network)で

  • 関数(状態)u(t)(時間に対する量の変化)
  • 未知パラメータk(減衰率)

同時に推定 する最小例を、Google Colabで再現できる形でまとめます。


TL;DR

  • PINNは「データに合う」だけでなく「方程式を満たす」ように学習するNN
  • 逆問題PINNでは 未知パラメータ(ここでは k)も学習変数にする
  • 今回の実行例では True k=1.5 に対して k_hat≈1.426 と推定(約5%低め)
  • データ損失がノイズで頭打ちになると、k は完全一致しないことがある(むしろ自然)
  • 改善策:重み付け、データ量、コロケーション点、最適化(Adam→LBFGS)など

1. PINNとは?(超ざっくり)

通常のNN回帰は「観測データに合うように」学習します。

一方PINNは、これに加えて

  • 微分方程式の残差(ズレ) が小さくなるように学習

します。

つまり損失がだいたい次の形になります:

  • データ損失:観測 u_obs(t) と予測 u_theta(t) のズレ
  • 物理損失:方程式残差 r(t) のズレ(理想は0)
  • 初期条件/境界条件損失:条件が満たされているか

2. 逆問題PINNとは?(未知パラメータも学習する)

逆問題PINNでは、NNの重み theta だけでなく、

  • k のような未知パラメータも 学習対象(trainable parameter)

にします。

学習で求めたいのは

  • u_theta(t)(関数)
  • k_hat(定数)

両方 です。


3. 最小例:指数減衰(kが未知)

今回は最小例として、次の1次常微分方程式を使います:

  • 方程式:u'(t) + k u(t) = 0
  • 初期条件:u(0)=1

解は(この玩具例では)解析的に分かっていて:

  • u(t) = exp(-k t)

ただし、実際の逆問題では真の k は未知なので、
ここでは ダミーデータ生成のためだけ に真値を使います。


4. 以下の5.のコードを実行して得られる図一覧

  • 図1:観測データ(ノイズ付き)と真の曲線
  • 図2:復元結果(PINN推定)と観測の比較
  • 図3:推定パラメータ k_hat の収束
  • 図4:損失の推移(total/data/phys/ic)
  • 図5:物理残差 r(t)=u'(t)+k u(t)(理想は0付近)

5. Google Colab用コード(そのまま実行OK)

このコードは「記事で理解するための最小構成」です。
実務では、スケーリング・重み付け・最適化・不確実性評価などを追加します(後述)。

# ====== Inverse PINN demo: estimate unknown k in u'(t) + k u(t) = 0 ======
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

# ---- reproducibility ----
np.random.seed(0)
torch.manual_seed(0)

device = "cuda" if torch.cuda.is_available() else "cpu"

# ---- toy ground truth (only for generating dummy data) ----
k_true = 1.5
u0 = 1.0
t_min, t_max = 0.0, 2.0

def u_true_fn(t):
    return np.exp(-k_true * t)

# ---- observations (sparse + noisy) ----
n_obs = 12
t_obs = np.linspace(t_min, t_max, n_obs)
u_true = u_true_fn(t_obs)

noise_sigma = 0.03
u_obs = u_true + np.random.normal(0, noise_sigma, size=n_obs)

t_obs_t = torch.tensor(t_obs, dtype=torch.float32, device=device).view(-1,1)
u_obs_t = torch.tensor(u_obs, dtype=torch.float32, device=device).view(-1,1)

# ---- collocation points for physics loss ----
n_f = 200
t_f = np.random.uniform(t_min, t_max, size=n_f)
t_f_t = torch.tensor(t_f, dtype=torch.float32, device=device).view(-1,1)
t_f_t.requires_grad_(True)

# ---- simple MLP for u(t) ----
class MLP(nn.Module):
    def __init__(self, width=64, depth=3):
        super().__init__()
        layers = [nn.Linear(1, width), nn.Tanh()]
        for _ in range(depth-1):
            layers += [nn.Linear(width, width), nn.Tanh()]
        layers += [nn.Linear(width, 1)]
        self.net = nn.Sequential(*layers)

    def forward(self, t):
        return self.net(t)

model = MLP(width=64, depth=3).to(device)

# ---- unknown parameter k as trainable variable (use log to keep k>0) ----
log_k = nn.Parameter(torch.tensor(np.log(0.5), dtype=torch.float32, device=device))
def k_hat():
    return torch.exp(log_k)

# ---- losses ----
mse = nn.MSELoss()

def du_dt(u, t):
    # autograd derivative du/dt
    return torch.autograd.grad(
        u, t, grad_outputs=torch.ones_like(u),
        create_graph=True, retain_graph=True
    )[0]

# initial condition point
t0 = torch.tensor([[0.0]], dtype=torch.float32, device=device, requires_grad=True)
u0_target = torch.tensor([[u0]], dtype=torch.float32, device=device)

# weights
w_data = 1.0
w_phys = 1.0
w_ic   = 10.0  # ICを少し強めに(今回の例では安定しやすい)

opt = torch.optim.Adam(list(model.parameters()) + [log_k], lr=1e-3)

# ---- training loop ----
steps = 4000
hist = {"total":[], "data":[], "phys":[], "ic":[], "k":[]}

for step in range(1, steps+1):
    opt.zero_grad()

    # data loss
    u_pred_obs = model(t_obs_t)
    loss_data = mse(u_pred_obs, u_obs_t)

    # physics loss on collocation points
    u_f = model(t_f_t)
    u_t = du_dt(u_f, t_f_t)
    r = u_t + k_hat() * u_f
    loss_phys = torch.mean(r**2)

    # initial condition loss
    u_0 = model(t0)
    loss_ic = mse(u_0, u0_target)

    loss_total = w_data*loss_data + w_phys*loss_phys + w_ic*loss_ic
    loss_total.backward()
    opt.step()

    # logging
    hist["total"].append(loss_total.item())
    hist["data"].append(loss_data.item())
    hist["phys"].append(loss_phys.item())
    hist["ic"].append(loss_ic.item())
    hist["k"].append(k_hat().item())

    if step % 500 == 0:
        print(f"[PINN-inverse] step={step:4d} total={loss_total.item():.6f} "
              f"(data={loss_data.item():.6f}, phys={loss_phys.item():.6f}, ic={loss_ic.item():.6f}) "
              f"k_hat={k_hat().item():.4f}")

print("\nTrue k =", k_true)
print("Estimated k_hat =", hist["k"][-1])

# ---- plotting ----
# fig1: data
tt = np.linspace(t_min, t_max, 200)
plt.figure(figsize=(6.2,4.2))
plt.plot(tt, u_true_fn(tt), label="ground truth (only for this toy)")
plt.scatter(t_obs, u_obs, label="noisy observations")
plt.title("Toy data: exponential decay with unknown k")
plt.xlabel("t"); plt.ylabel("u(t)")
plt.grid(True); plt.legend()
plt.show()

# fig2: prediction
with torch.no_grad():
    tt_t = torch.tensor(tt, dtype=torch.float32, device=device).view(-1,1)
    u_pinn = model(tt_t).cpu().numpy().reshape(-1)

plt.figure(figsize=(6.2,4.2))
plt.plot(tt, u_true_fn(tt), label="ground truth (toy)")
plt.scatter(t_obs, u_obs, label="noisy observations")
plt.plot(tt, u_pinn, label=f"PINN inverse (k_hat={hist['k'][-1]:.3f})")
plt.title("Reconstruction + parameter estimation")
plt.xlabel("t"); plt.ylabel("u(t)")
plt.grid(True); plt.legend()
plt.show()

# fig3: k convergence
plt.figure(figsize=(6.2,4.2))
plt.plot(hist["k"], label="k_hat")
plt.axhline(k_true, linestyle="--", label="k_true (toy)")
plt.title("Estimated parameter k during training")
plt.xlabel("step"); plt.ylabel("k")
plt.grid(True); plt.legend()
plt.show()

# fig4: losses (log scale)
plt.figure(figsize=(6.2,4.2))
plt.semilogy(hist["total"], label="total")
plt.semilogy(hist["data"], label="data")
plt.semilogy(hist["phys"], label="phys")
plt.semilogy(hist["ic"], label="ic")
plt.title("Loss curves (log scale)")
plt.xlabel("step"); plt.ylabel("loss")
plt.grid(True); plt.legend()
plt.show()

# fig5: residual on a grid
t_eval = torch.tensor(tt, dtype=torch.float32, device=device).view(-1,1)
t_eval.requires_grad_(True)
u_eval = model(t_eval)
u_eval_t = du_dt(u_eval, t_eval)
r_eval = (u_eval_t + k_hat() * u_eval).detach().cpu().numpy().reshape(-1)

plt.figure(figsize=(6.2,4.2))
plt.plot(tt, r_eval, label="residual r(t)=u'(t)+k u(t)")
plt.axhline(0.0, linestyle="--")
plt.title("Physics residual (should be near 0)")
plt.xlabel("t"); plt.ylabel("residual")
plt.grid(True); plt.legend()
plt.show()

6. 得られる結果の解説

学習が進むほど k_hat が増え、最終的に

  • True k = 1.5
  • Estimated k_hat ≈ 1.4255

になります(差は約0.074、相対で約5%)。

また、損失の内訳は終盤でだいたい:

  • total ≈ 4.0e-4
  • data ≈ 3.78e-4(支配的)
  • phys ≈ 2.2e-5(十分小さい)
  • ic ≈ 0(ほぼ満たせている)

という状態になります。


6.1 図1(データ):ノイズがあると「完全一致」より「妥当な推定」になる

fig1_data(1).png

観測点は真の指数曲線の近くにありますが、ノイズで上下にぶれています。

この時点で重要なのは:

  • データ損失は、ノイズがある限り 0 にはならない
  • したがって、k の推定も「真値に完全一致」ではなく
    ノイズと物理制約の折衷点 になりやすい

ということです。


6.2 図2(再構成):k_hatが小さいと減衰が少し遅く見える

fig2_prediction.png

指数減衰 u(t)=exp(-k t) では、

  • k が小さいほど減衰が遅い(後半で値が高めになる)

なので、k_hat=1.426 < 1.5 の場合、推定曲線は(特に後半で)真の曲線より少し上に出やすくなります。

ただし観測点がノイズで散っているため、
「真値から少しズレたkでもデータには十分合う」状況になり得ます。


6.3 図3(k収束):単調に近づいているが、最後に少し頭打ち

fig3_k_convergence.png

実行結果のログを見ると:

  • step 500: 0.692
  • step 2000: 1.250
  • step 3000: 1.387
  • step 4000: 1.426

と、かなり素直に収束しているはずです。

一方で、後半(3000→4000)は伸びが小さいので、

  • すでに 損失がほぼ頭打ち
  • k を動かしても total があまり改善しない

という「最適化の終盤」に入っていると考えられます。


6.4 図4(損失カーブ):最後は data loss が支配的=「ノイズの壁」

fig4_losses.png

終盤で datatotal の大半を占めているので、

  • これ以上 total を下げたくても、観測ノイズが邪魔をしている
  • physic は既に十分小さいため、改善余地が少ない

という状態です。

逆問題PINNではありがちな形で、

  • 物理は満たせる
  • でもデータはノイズで完全一致できない
  • その折衷として k_hat が少しバイアスする

という構図が見えます。


6.5 図5(残差):理想は0、現実には「小さいがゼロではない」

fig5_residual.png

残差図は、±0.01程度の振れが見えます。

  • “理想は0だが、有限ステップ・ノイズ・重み付けのトレードオフで完全には0にならない。
    ただし十分小さければ、方程式整合性は概ね確保できている”

phys loss ≈ 2e-5 は小さいので、
“だいたい物理を満たせている”と言ってよい。


7. 考察:なぜ k_hat は 1.5 に完全一致しなかった?

今回の結果(k_hat≈1.426)を踏まえると、主に次の可能性が高いです。

7.1 観測ノイズによるバイアス(「厳密解」より「ノイズに合う解」)

データ損失が支配的になると、PINNは

  • “方程式を完全に満たす”より
  • “ノイズ込みの観測に合う”

方向へ引っ張られます。
このとき u(t) の形と k はトレードオフになり、k が少しズレることがあります。

7.2 同時推定ゆえの自由度(uとkが“分担”できる)

今回は u(t) をNNで表現しているので、指数関数以外の形も表現可能です。

  • u(t) を少し歪めることで
  • k を真値からずらしても
  • データ+物理残差の合計が小さくできる

…という「逃げ道」が存在します(これが“逆問題の難しさ”でもあります)。

7.3 重み付け(w_phys, w_data, w_ic)のバランス

終盤の損失を見ると data が強いので、

  • 物理をもっと強くしたいなら w_phys を上げる

という改善の余地があります。
ただし上げすぎると逆に「データを無視して物理だけ満たす」方向へ行き、
観測と合わなくなるので調整が必要です。

7.4 最適化(Adamだけだと“あと一歩”で止まることがある)

PINNは損失地形が厳しいことがあり、

  • Adamで荒く近づく
  • L-BFGSで仕上げる

の2段階で良くなるケースがあります。


8. 改善案(kをもっと真値に寄せたいとき)

以下は「玩具例で真値が分かっているから言える」改善です。
実データでは真値が分からないので、残差・再現性・外部検証で妥当性を確認してください。

  • 観測点を増やす / 観測区間を広げる
    k は時間スケールの情報なので、終盤の点が増えると識別しやすいです。

  • ノイズを下げる(または損失をノイズ分散でスケールする)
    ノイズが大きいほど、data loss が頭打ちになって k がブレます。

  • w_phys を少し上げる
    例:w_phys=510 を試して、k_hat の安定性とデータ適合のバランスを見る。

  • Adam→L-BFGSの2段階最適化
    仕上げに強い場合があります。

  • 初期条件を“ハード制約”にする(ic lossを消せる)
    例えば u(t)=u0 + t * NN(t) とすると、u(0)=u0 を必ず満たせます。
    icの学習が不安定な場合に効きます(今回icは十分小さいですが、一般論として有用)。

  • 多初期値(multi-start)でkを試す
    k の初期値により収束先が変わることがあります。


9. 実務にどう効く?(教育・研究・実験の文脈)

この最小例は「指数減衰」ですが、実務では例えば:

  • 薬物動態:血中濃度の減衰(消失速度定数)
  • 生体信号:回復過程の時定数推定
  • センサー:温度応答・緩和モデルの係数推定
  • 材料・工学:拡散係数、熱伝導率、反応速度の推定

のように、モデルを信じたいが係数が不明 という場面で「逆問題PINN」の発想が刺さります。


10. まとめ

  • 逆問題PINNでは u(t) と k を同時に学習できる
  • 本記事の結果は、損失内訳・残差・収束の形が素直で、学習は概ね成功している
  • k_hat が真値より少し低いのは、ノイズと損失バランスの影響として自然
  • 残差は、“理想は0だが、ノイズやトレードオフで完全一致しないことがある”
  • 精度を上げたいときは、重み付け・データ量・最適化(Adam→LBFGS)などの調整が有効
1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?