0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

DeepONet入門:仕組み・実装例まとめ

Last updated at Posted at 2025-05-10

タイトル

更新履歴

  • 2025/05/11:初版

はじめに

ハードウェアの進歩に伴い、データ駆動型アプローチがトレンドとなっている。しかし、現実世界や実際のエンジニアリング現場では、観測データが限られていたり(たとえば衝突現象)、予測モデルに外挿性が求められるケースが散見される。特に、物理法則に基づくモデリングの重要性は依然として高い。

作用素学習 は、関数空間間の写像を学習する技術であり、物理現象の高速シミュレーションや複雑なデータ解析の効率化を目的として開発された。これは、従来の数値計算手法では困難だった複雑系のモデリングに対して、新たなパラダイムシフトをもたらしつつある。

その代表手法として DeepONetFourier Neural Operator がよく挙げられるが、本稿では前者である DeepONet に焦点を当てる。元論文はこちら

ちなみに、2025年5月11日現在Qiita上でDeepONetをはじめとする作用素学習を主題にした記事はないので、本稿は初となるっぽいのである。

作用素学習

DeepONetの具体的な手法説明の前に、作用素学習について簡単に解説する。

通常の機械学習では、関数 $f:\mathbb{R}^n\mapsto\mathbb{R}^m$ を学習するが、作用素学習では、次の様に関数空間 $\mathcal{A}$ から関数空間 $\mathcal{U}$ への写像(作用素)$\mathcal{G}$ を学習する:

\mathcal{G}:\mathcal{A}\mapsto\mathcal{U}
\tag{1}

例えば、次のような偏微分方程式(以下、PDE)を考えたとき:

\mathcal{L}_au(\boldsymbol{x})
=
f(\boldsymbol{x})
\tag{2}

この解 $u$ をパラメータ $a(\boldsymbol{x})$ に対する作用素 $\mathcal{G}:a\mapsto u$ と考えて学習する。

具体例:1次元ポアソン方程式

次のような2階偏微分方程式を考える:

-
\cfrac{
    d^2u
}{
    dx^2
}
=
a(x)
,\quad
x\in[0,1]
,\quad
u(0)=u(1)=0
\tag{3}

このとき作用素 $\mathcal{G}:a(x)\mapsto u(x)$ を学習する。
例えば、入力関数が $a(x)=\sin(\pi x)$ のとき、開区間 $x\in[0,1]$ を適当に($n=5$)で均等に離散化すると:

x
=
[
    0,\,
    1/4,\,
    1/2,\,
    3/4,\,
    1
]
\tag{4}

このときの入力ベクトル $a=[a(x_1),\cdots,a(x_5)]$ は:

\begin{align}
a
&=
[
    \sin{0},\,
    \sin{\pi/4},\,
    \sin{\pi/2},\,
    \sin{3\pi/4},\,
    \sin{\pi}
]
\tag{5}
\\
&\simeq
[
    0,\,
    0.7071,\,
    1,\,
    0.7071,\,
    0
]
\tag{6}
\end{align}

そしてこれに対する出力ベクトル $u=[u(x_1),\cdots,u(x_5)]$ は:

\begin{align}
u
&=
[
    \frac{1}{\pi^2}\sin{0},\,
    \frac{1}{\pi^2}\sin{\pi/4},\,
    \frac{1}{\pi^2}\sin{\pi/2},\,
    \frac{1}{\pi^2}\sin{3\pi/4},\,
    \frac{1}{\pi^2}\sin{\pi}
]
\tag{7}
\\
&\simeq
[
    0,\,
    0.0716,\,
    0.1013,\,
    0.0716,\,
    0
]
\tag{8}
\end{align}

故に、これを1ペアとして:

  • 入力:$[
    0,,
    0.7071,,
    1,,
    0.7071,,
    0
    ]$
  • 出力:$[
    0,,
    0.0716,,
    0.1013,,
    0.0716,,
    0
    ]$

これを1つの訓練データとして扱う。

コンピュータサイエンスや数値解析では、関数を直接扱うことはできず、離散化してベクトルや行列で近似・表現する。

つまり、関数 $u(\boldsymbol{x})\simeq$ ベクトル $\boldsymbol{u}$ として扱う。

DeepONetと普遍近似定理

ニューラルネットワークの普遍近似能力は、従来型の有限次元写像の場合と、関数空間間の写像を対象とする場合とで本質的に異なる。ここでは、通常の多層パーセプトロン(MLP)に対する普遍近似定理と、DeepONet における作用素の普遍近似定理を、数式モデルに基づいて比較する。

通常の順伝播型ニューラルネットワークにおける普遍近似定理

対象は関数:

f:\mathbb{R}^m\mapsto\mathbb{R}^n\tag{9}

である。古典的な普遍近似定理は次を主張する。

定理(順伝播NNの普遍近似)

任意のコンパクト集合 $K\subset\mathbb{R}^n$ 上の連続関数 $f\in C(K, \mathbb{R}^m)$ は有限個のノードを持つ1層のニューラルネットワーク $\hat{f}$ が存在して:

\sup_{\boldsymbol{x}\in K}
\|
    f(
        \boldsymbol{x}
    )
    -
    \hat{f}
    (
        \boldsymbol{x}
    )
\|
<
\varepsilon
\tag{10}

が成立する。
より具体的に、ネットワーク構造は次の形で表現できる:

\hat{f}
(
    \boldsymbol{x}
)
=
\sum_{k=1}^p
\boldsymbol{c}_k
\sigma
\left(
    \langle
      \boldsymbol{w}_k,
      \boldsymbol{x}
    \rangle
    +
    \zeta_k
\right)
\tag{11}

ここで、$\sigma (\cdot)$ は活性化関数、$\boldsymbol{c}_i\in\mathbb{R}^m,\ \boldsymbol{w}_i\in\mathbb{R}^n,\ \zeta_i\in\mathbb{R}$ は学習パラメータである。

順伝播型ニューラルネットワークは、有限次元空間間の連続関数を任意精度で近似できる。

DeepONet における普遍近似定理

DeepONet では、対象が関数そのものから別の関数への写像、すなわち作用素である。対象は:

\mathcal{G}:
V\subset C(K_1)
\mapsto
C(K_2)
\tag{12}

という非線形連続作用素である。
ここで、$V$ はあるコンパクト集合、$K_1,,K_2$ は有限次元空間上のコンパクト集合であり、$C(K)$ は $K$ 上の連続関数空間を表す。

この時、T.Chen et al., 1995が証明した普遍近似定理をベースに、DeepONetの設計に対応する定理は以下となる。

定理(DeepONetにおける普遍近似)

任意の $\mathcal{G}\in C(V, C(K_2))$ と $\varepsilon>0$ に対し、あるブランチネットワーク及びトランクネットワークのパラメータを適切に取ることで、DeepONet $\hat{\mathcal{G}}$ が存在して:

\sup_{u\in V,\boldsymbol{y}\in K_2}
\|
    \mathcal{G}(u)(\boldsymbol{y})
    -
    \hat{\mathcal{G}}(u)(\boldsymbol{y})
\|
<
\varepsilon
\tag{13}

が成立する。

DeepONetによる近似は、以下のような線形結合として構成される:

\hat{\mathcal{G}}
(u)(\boldsymbol{y})
=
\sum_{k=1}^p
b_k(u)\,
t_k(\boldsymbol{y})
\tag{14}

ネットワーク構造

式 $(14)$ が示す通り、DeepONetは $b_k$:ブランチネット(Branch Net)$t_k$:トランクネット(Trunk Net) の2種類のネットワークを有する。

ブランチネット

  • 入力:関数 $u$ の値ベクトル:

    \boldsymbol{u}
    =
    \left[
        u(\boldsymbol{x}_1),
        \cdots,
        u(\boldsymbol{x}_m)
    \right]
    ^\top
    \in
    \mathbb{R}^m
    \tag{15}
    
  • 出力:ベクトル $\boldsymbol{b}\in\mathbb{R}^p$

  • 構造:MLP

  • 役割:入力関数 $u$ の符号化

  • 数式モデル:

    \begin{align}
    b_k
    &=
    \sum_{i=1}^n
    c_{k,i}\,
    \sigma
    \left(
        \sum_{j=1}^m
        \xi_{k,i,j}
        u(\boldsymbol{x}_j)
        +
        \theta_{k,i}
    \right)
    \tag{16}
    \\
    &=
    \sum_{i=1}^n
    c_{k,i}\,
    \sigma
    \left(
        \langle
    \boldsymbol{\xi}_{k,i},
            u(
                \boldsymbol{x}
            )
        \rangle
        +
        \boldsymbol{\theta}_k
    \right)
    \tag{17}
    \\
    &=
    \langle
        \boldsymbol{c}_k,
        \sigma
        \left(
            \langle
                \boldsymbol{\Xi}_{k},
                u(
                    \boldsymbol{x}
                )
            \rangle
            +
            \boldsymbol{\theta}_k
        \right)
    \rangle
    \tag{18}
    \end{align}
    

    これを見て分かる様に、出力 $b_k$ に対して個別にブランチネットが構成 されている。これを Stacked型 と呼ぶ。成分ごとに別々のネットワークで計算しているだけあり、出力表現力は高いが、計算コストが高い。

    これに対し、ひとつのブランチネットが全ての $b_k$ を一括で出力するUnstacked型 というものがある:

    b_k
    =
    \langle
        \boldsymbol{c}_k,
        \sigma
        \left(
            \langle
                \boldsymbol{\Xi},
                u(
                    \boldsymbol{x}
                )
            \rangle
            +
            \boldsymbol{\theta}
        \right)
    \rangle
    \tag{19}
    

    Unstacked型では、パラメータは共有されており、計算/メモリ効率も良い。しかしながら、当然表現力は低下するため、複雑な作用素に対しては不利となる。

    [論文抜粋画像]

    image.png

トランクネット

  • 入力:評価点 $\boldsymbol{y}\in\mathbb{R}^d$

  • 出力:ベクトル $\boldsymbol{t}\in\mathbb{R}^p$

  • 構造:MLP

  • 役割:出力座標 $y$ に対応する空間的基底関数の重みベクトル

  • 数式モデル:

    \begin{align}
        t_k
        =
        \sigma
        \left(
            \langle
                \boldsymbol{w}_k,
                \boldsymbol{y}
            \rangle
            +
            \zeta_k
        \right)
    \tag{20}
    \end{align}
    

データ生成方法

DeepONetは作用素を学習するため、入力関数 $u(\boldsymbol{x})$ の表現力が予測精度の上限を決定する。ここでは、入力関数 $u$ をどのように生成し、データセット $\left( u,\boldsymbol{y},\mathcal{G}(u)(\boldsymbol{y})\right)$ を構成するかについて説明する。
DeepONet論文内では、入力関数を次の2種類の関数空間からサンプリングすることを提案している。

ガウス過程(GRF: Gaussian Random Field)

入力関数 $u$ を平均 $0$ のGRFとして生成する:

\begin{align}
u
\sim
\mathcal{GP}
\left(
    0,\,
    k_l(x_1,x_2)
\right)
\tag{21}
\\
k_l(x_1,x_2)
=
e^{
    -
    \frac
    {
        \|x_1-x_2\|^2
    }{
        2l^2
    }
}
\tag{22}
\end{align}
  • $l>0$:RBFカーネルの自由パラメータ
  • $l$ が大きい→滑らか
  • $l$ が小さい→高周波成分を含む関数

GRFは確率的で多様な入力関数を生成するため、動的システムの様々な挙動を学習させるのに適している。

チェビシェフ多項式(Chebyshev Polynomial)

V_{\rm poly}
=
\left\{
    \left.
        \sum_{i=0}^{N-1}
        a_iT_i(x)
    \right|
    \|a_i\|\leq M
\right\}
\tag{23}
  • $T_i(x)$:第一種チェビシェフ多項式
  • $a_i\in\left[-M, M\right]$:係数は一様分布で生成

この関数空間では、$u$ は決定的に選択されるが、制御された形状(=関数が有次元ベクトルで明示的に決まる)を持ち、解析しやすい利点を持つ。

データ生成フロー

  • Step.1:入力関数のサンプルを生成(ガウス過程 or Chebyshev多項式)
  • Step.2:ODEないしPDEを数値的に解く
  • Step.3:データ点、トリプレットの構成
    • ひとつの入力関数 $u$ に対して複数の評価点 $y$ で出力 $\mathcal{G}(u)(\boldsymbol{y})$ を記録
    • 各データ点は三つの組 $(u,\boldsymbol{y},\mathcal{G}(u)(\boldsymbol{y}))$ で構成

簡単な動的システムのデータ生成例(ODE編)

次のODEを考える:

\frac{
    ds
}{
    dx
}
=
-s(x)
+u(x)
,\quad
s(0)=0
\tag{24}
  • $s(x)$:状態(出力)
  • $u(x)$:入力関数(センサー観測)
  • 区間:$x\in[0,1]$

まず、このODEの解析解は次の様になる:

\begin{align}
&&
\frac{ds}{dx}
=
-s(x)
+u(x)
\tag{25}
\\
&\Rightarrow&
\frac{ds}{dx}
+s(x)
=
u(x)
\tag{26}
\end{align}

積分因子は $e^x$ なので、これを両辺に掛けて:

\begin{align}
&&
e^x\frac{ds}{dx}
+e^xs(x)
=e^xu(x)
\tag{27}
\\
&\Rightarrow&
\frac{d}{dx}
\left(
    e^xs(x)
\right)
=e^xu(x)
\tag{28}
\\
&\Rightarrow&
e^xs(x)
=
\int_0^x
e^tu(t)
\,dt+C
\tag{29}
\end{align}

初期条件 $s(0)=0$ から積分定数 $C=0$

\begin{align}
s(x)
&=
e^{-x}
\int_0^x
e^tu(t)\,dt
\tag{30}
\\
&=\int_0^x
e^{-(x-t)}u(t)\,dt
\tag{31}
\end{align}

よって、非線形作用素 $\mathcal{G}$ が入力関数 $u$ を受け取って、出力関数 $s=\mathcal{G}(u)$ を返すことになる。

Step.1 入力関数のサンプルを生成(GRFベース)

どちらでも良いが、今回はGRFを用いて生成する:

import numpy as np

x = np.linspace(0, 1, 101)


def rbf_kernel(x, l=0.2):
    dists = np.subtract.outer(x, x) ** 2
    return np.exp(-dists / (2 * l**2))


l = 0.1
cov = rbf_kernel(x, l)
mean = np.zeros_like(x)

u = np.random.multivariate_normal(mean, cov)
Step.2 ODEをルンゲ=クッタ法で数値的に解く

入力 $u$ を固定した上で、以下のようにルンゲ=クッタ法でODEを解く:

from scipy.integrate import solve_ivp
from scipy.interpolate import interp1d


def ode_rhs(x, s, u_interp):
    return -s + u_interp(x)


u_interp = interp1d(x, u, kind="linear")
sol = solve_ivp(
    fun=lambda x, s: ode_rhs(x, s, u_interp), t_span=[0, 1], y0=[0], t_eval=x
)

s = sol.y[0]
Step.3 データ点、トリプレットの構成
  • $\boldsymbol{u}$:101点の関数サンプル
  • $\boldsymbol{y}$:関数出力の評価点
  • $s(\boldsymbol{y})$:上で求めた解の対応値

イメージ:

入力関数:$u$ 評価点:$\boldsymbol{y}$ 出力:$s(\boldsymbol{y})=\mathcal{G}(u)(\boldsymbol{y})$
長さ101のベクトル $0$ $s(0)$
同上 $0.01$ $s(0.01)$
$\vdots$ $\vdots$ $\vdots$
同上 $1$ $s(1)$
U, Y, S = [], [], []

s_interp = interp1d(x, s, kind='cubic')

for y in np.linspace(0, 1, 101):
    U.append(u.copy())
    Y.append([y])
    S.append([s_interp(y)])


data = (U, Y, S)

以上で、1つの入力関数 $u$ に対して、多点 $y\in[0,1]$ での出力 $s(y)$ を求めてデータセットを作るコード になる。

ここで、kind='cubic' としたのは、ODE解から得られた $s$ の補間には滑らかさが重要だと考えたからである。

簡単な動的システムのデータ生成例(PDE編)

2つ目の例として、1次元拡散方程式を考える:

\begin{align}
\cfrac
{
    \partial s
}{
    \partial t
}
=
\cfrac
{
    \partial^2 s
}{
    \partial x^2
}
,\quad
x\in[0,1]
\tag{32}
,\quad
t\in(0,1]
\end{align}
  • 境界条件:$s(0,t)=s(1,t)=0$
  • 初期条件(入力関数):$s(x,0)=u(x)$
Step.1 入力関数のサンプルを生成(Chebyshevベース)

今回はチェビシェフ多項式を用いて入力関数を生成してみる:

import numpy as np
from numpy.polynomial.chebyshev import Chebyshev

x = np.linspace(0, 1, 101)


def generate_u(x, degree=5):
    coeffs = np.random.uniform(-1, 1, size=degree)
    T = Chebyshev(coeffs, domain=[0, 1])
    return T(x)


u = generate_u(x)
Step.2 拡散方程式をFTCS法で数値的に解く

入力 $u$ を固定して、FTCS法で解く:

def solve_diff1d(u, x, ts):
    nx = x.shape[0]
    nt = ts.shape[0]
    dx = x[1] - x[0]
    dt = ts[1] - ts[0]
    alpha = dt / dx**2

    # Initial and boundary conditions
    s = np.zeros((nx, nt))
    s[:, 0] = u
    s[0, :] = 0
    s[-1, :] = 0

    # FTCS: Forward Time Centered Space
    for t in range(0, nt - 1):
        for i in range(1, nx - 1):
            s[i, t + 1] = s[i, t] + alpha * (s[i + 1, t] - 2 * s[i, t] + s[i - 1, t])

    return s


ts = np.linspace(0, 1, 51)
s = solve_diff1d(u, x, ts)
Step.3
入力関数:$u$ 評価点:$\boldsymbol{y}=(x, t)$ 出力:$s(\boldsymbol{y})=\mathcal{G}(u)(\boldsymbol{y})$
長さ101のベクトル $(0, 0)$ $s(0, 0)$
同上 $(0.01, 0)$ $s(0.01, 0)$
$\vdots$ $\vdots$ $\vdots$
同上 $(1,1)$ $s(1,1)$
xx, tt = np.meshgrid(x, ts)
Y = np.stack([xx.ravel(), tt.ravel()], axis=1)
S = s.flatten()
U = np.tile(u, (len(S), 1))

data = (U, Y, S)

入力関数のサンプリング点数

本章では、非線形動的システムの同定に必要なセンサー数(入力関数のサンプリング点数)を、理論と数値の両面から検討する。

問題設定

以下の様な非線形力学系を考える:

\begin{cases}
\cfrac{
    d
}{
    dx
}
\boldsymbol{s}(x)
=
\boldsymbol{g}
\left(
    u(x),\,
    x,\,
    \boldsymbol{s}(x)
\right)
\\
\boldsymbol{s}(a)
=
\boldsymbol{s_0}
\end{cases}
\tag{33}
  • $u(x)\in V\subset C[a,b]$:入力関数
  • $s(x)\in\mathbb{R}^K$:出力関数

ここで、入力 $u$ に対して解 $s$ を返す作用素を $\mathcal{G}$ とすると、$\mathcal{G}$ は次の様な積分形式で表現される:

\mathcal{G}(u)(x)
=
\boldsymbol{s}_0
+
\int_a^x
\boldsymbol{g}
\left(
    u(t),\,
    t,\,
    \mathcal{G}(u)(x)
\right)
\,
dt
\tag{34}

入力関数の近似(サンプリング法)

センサーで $m+1$ 点を一様分割し、入力関数 $u(x)$ を線形補完で近似したものを $u_m(x)$ と定義:

u_m(x)
=
u(x_j)
+
\cfrac
{
    \Delta u
}{
    \Delta x
}
\,
(
    x-x_j
)
\tag{35}
  • $x_j=a+j(b-a)/m$:区間 $[a,b]$ を $m$ 分割した時の $j$ 番目の点
  • $\Delta x=x_{j+1}-x_j=1/m$
  • $\Delta u=u(x_{j+1})-u(x_j)$

この時、近似誤差の上限は次のように表される:

\max_{
    x\in[a,b]
}
\|
    u(x)
    -
    u_m(x)
\|
\leq
\kappa(
    m, V
)
\tag{36}

ここで、$\kappa(m, V)$ は関数空間 $V$ に属する関数に対して、サンプル数 $m$ で得られる最大の近似誤差を表す。

例えば、$u$ がRBFカーネルで定義されたGRF上の関数である時、それは $C^2$ 級関数であるので:

\|
    u(x)
    -
    u_m(x)
\|
\leq
C
\cdot
\frac{
    (b-a)^2
}{
    m^2
}
\cdot
\max_{x\in[a,b]}
|
    u^{\prime\prime}(x)
|
\tag{37}

$|u^{\prime\prime}(x)|^2$ の期待値のオーダーに着目すると:

\begin{align}
&&
\mathbb{E}
\left[
    |
        u^{\prime\prime}
        (x)
    |^2
\right]
=
\mathcal{O}
\left(
    \frac{1}{l^4}
\right)
\tag{38}
\\
&\Rightarrow&
\max
|
    u^{\prime\prime}(x)
|
=\mathcal{O}
\left(
    \frac{1}{l^2}
\right)
\tag{39}
\end{align}

よって:

\max_{
    x\in[a,b]
}
\|
    u(x)
    -
    u_m(x)
\|
\leq
\mathcal{
    O
}
\left(
    \frac{1}{m^2l^2}
\right)
\tag{40}

センサー数に対する誤差理論

以下の条件を満たす様なセンサー数 $m$ を選べば、近似誤差を任意の精度 $\varepsilon$ 以下に抑えられる:

c(b-a)\kappa(m,V)e^{c(b-a)}\leq\varepsilon
\tag{41}

$c$ は関数 $g$ のリプシッツ定数である。故に、必要なセンサー数 $m$ は次の様に求められる:

m\geq\frac{1}{l}
\tag{42}

即ち、入力関数 $u$ のパラメータ $l$ が小さい(=荒い関数)ほど、多くの千数が必要になる。

センサー数と誤差の関係

  • 小さい $m$:誤差は指数関数的に減少:

    \text{MSE}\propto 4.6^{-m}\tag{43}
    
  • 大きい $m$:10点程度で飽和現象が見られる

  • 時間予測長 $T$ が増えると、必要なセンサー数も増える:

    m\propto\sqrt{T}\tag{44}
    
  • 入力関数が粗い($l$ が小さい)ほど、必要なセンサー数が増加:

    m\propto l^{-1}\tag{45}
    

DeepONetの実装例

ここからは、実際に PyTorch を用いて DeepONet による作用素学習を行う。
例として、前節で扱った式 $(24)$ の常微分方程式を取り上げ、以下の2つの設定でモデル精度を比較する:

  • サンプリング点数 $m$ およびスケールパラメータ $l$ ごとの MSE 比較
  • DeepONet と FNN の性能比較

なお、両者の比較は同一の関数(データセット)を用いて実施し、$m$ や $l$ に依存しないよう工夫する。また、乱数シードを固定することで再現性を担保したデータセット生成を行う。

計算環境

  • インスタンス:g4dn.xlarge
  • AMI:ami-0a8ee7e764943feef
  • GPU:NVIDIA T4 Tensor Core GPU
  • CPU:4 vCPU Intel Xeon
  • RAM:16GB
  • ストレージ:EBS SSD 128GB
  • Python:3.11.8

requirements.txt

# Core
numpy>=1.26
scipy>=1.11
matplotlib>=3.7

# PyTorch + Lightning
torch>=2.2
pytorch-lightning>=2.2
torchmetrics>=1.0
tensorboard>=2.14
tensorboardX>=2.6

# Optional
rich>=13.0
tqdm>=4.60

モジュール

論文中にはGithubが公開されており、examplesの中にある、不定積分作用素学習:antiderivative_aligned_pideeponet.py あたりが参考になりそうだったが、ddeライブラリの仕様や使い方がよくわからなかったので、勉強も兼ねてスクラッチすることにした。私は PyTorch 派閥なので、PyTorch Lightning の作法に倣う。

.
├── dataset
│   ├── __init__.py
│   ├── deeponet_dataset.py
│   └── mlp_dataset.py
├── function_space
│   ├── __init__.py
│   └── grf.py
├── main.py
├── main2.py
├── model
│   ├── __init__.py
│   ├── mlp.py
│   └── unstacked_deeponet.py
├── plot_mse.py
├── plot_vs_mlp.py
├── requirements.txt
└── utils
    ├── __init__.py
    └── utils.py

5 directories, 15 files

GRF クラス

GRF クラスは、入力関数 $u(x)$ を生成するための関数ジェネレータとして設計されている。
従来はルンゲクッタ法により任意の常微分方程式(ODE)の右辺(RHS)を定義し、数値解を求めていたが、これは計算時間がかかるため非効率である。

ここでは、式 $(24)$ のODEの解が式 $(31)$ のように解析的に積分表示で与えられることを利用し、入力関数 $\boldsymbol{u}$ に積分カーネルを畳み込んだ出力関数 $\boldsymbol{s}$ を高速に生成する。

このクラスでは、あらかじめ高解像度の離散点 $x_{\text{dense}}$ を定義し、それに基づいた積分行列(下三角の指数カーネル)を構築する。
random メソッドにおいては、指定された点数 $m$ に応じて $x_{\text{dense}}$ から等間隔に点を抽出し、対応する $\boldsymbol{u}$ および $\boldsymbol{s}$ のベクトルを返す。

この実装により、ODEを逐一数値解くことなく、作用素学習のための大規模データセットが効率的に構築できる:

Step.1

はじめに、与えられた区間 span において密な等間隔グリッドを定義する:

\begin{align}
\begin{cases}
\lbrace x_j\rbrace_{j=0}^{n-1}\in\text{span}
\\
\Delta x:=x_j-x_{j-1}
\\
n:=\text{n_dense}
\end{cases}
\tag{46}
\end{align}
Step.2

入力関数 $\boldsymbol{u}$ は平均 $0$ の共分散関数:

\mathbb{E}
\left[
    u(x_1)
    u(x_2)
\right]
=
e^{
    -
    \frac
    {
        \|x_1-x_2\|^2
    }{
        2l^2
    }
}
\tag{47}

に従う多変量正規分布からサンプリングされる。この離散化により:

\begin{align}
u^{(k)}
\sim
\mathcal{GP}
\left(
    0,\,
    k_l(x_i,x_j)
\right)
\tag{48}
\\
\boldsymbol{K}_{ij}
=
e^{
    -
    \frac
    {
        \|x_i-x_j\|^2
    }{
        2l^2
    }
}
\tag{49}
\end{align}

として、任意の関数本数 $k=1,\cdots,\text{n_funcs}$ にわたる $\boldsymbol{u}^{(k)}\in\mathbb{R}^n$ が得られる。

Step.3

解 $\boldsymbol{s}$ を離散的に生成するために、以下の積分演算を下三角行列 np.tril によって離散化する:

s(x_i)
=
\int_o^{x_i}
e^
{
    -
    \left(
        x_i-t
    \right)
}
u(t)\,dt
\simeq
\sum_{j=0}^i
e^
{
    -
    \left(
        x_i-t
    \right)
}
u(x_j)\Delta x
\tag{50}

これの行列表現は:

\begin{align}
\boldsymbol{s}
=
\boldsymbol{K}_{\rm tril}
\boldsymbol{u}
,\quad
\left(
    \boldsymbol{K}_{\rm tril}
\right)_{ij}
=
\begin{cases}
    e^{-(x_i-x_j)}\Delta x,\quad&\text{if }j\leq i\\
    0&\text{otherwise}
\end{cases}
\tag{51}
\end{align}
Step.4

任意のサンプリング点数 $m$ が与えられた時、以下の様に等間隔でサンプリングする:

x^{(m)}
=
\left\{
    x_0^{(m)},
    x_1^{(m)},
    \cdots,
    x_{m-1}^{(m)}
\right\}
\subset
\left\{
    x_0,
    x_1,
    \cdots,
    x_{n-1}
\right\}
\tag{52}

これは np.searchsorted によって行われる。
以上の操作により観測点における $\boldsymbol{u},,\boldsymbol{s}$ のペアが得られる。

dataset/grf.py
import numpy as np


class GRF:
    def __init__(self, span=(0.0, 1.0), n_dense=1e3):
        self.span = span
        self.n_dense = int(n_dense)
        self.x_dense = np.linspace(*span, self.n_dense)
        self.dx = np.median(np.diff(self.x_dense))
        self.kernel = self._build_kernel(self.x_dense, self.dx)

    def _build_kernel(self, x, dx):
        xx, tt = np.meshgrid(x, x, indexing="ij")
        K = np.exp(-(xx - tt)) * dx
        return np.tril(K)

    def _generate_grf(self, ell, n_funcs, seed):
        D2 = np.subtract.outer(self.x_dense, self.x_dense) ** 2
        K = np.exp(-D2 / (2 * ell**2))
        mean = np.zeros(self.n_dense)
        rng = np.random.default_rng(seed)
        return rng.multivariate_normal(mean, K, size=n_funcs)

    def random(self, m, ell, n_funcs, seed=0):
        U_dense = self._generate_grf(ell, n_funcs, seed)
        S_dense = U_dense @ self.kernel.T
        x_m = np.linspace(*self.span, m)
        idx = np.searchsorted(self.x_dense, x_m)
        return U_dense[:, idx], S_dense[:, idx], x_m

DeepONetDataset クラス

このクラスは torch.utils.data.Dataset を継承し、DeepONet の訓練に適したデータセットを構築する。
PyTorch の DataLoader に渡してバッチ学習を行うことが可能である。

具体的には、入力関数ベクトル $\boldsymbol{u} \in \mathbb{R}^m$ を与えたときに、任意の出力点 $y \in [0,1]$ における作用素出力 $\boldsymbol{s}(y)$ を、補間を通じて評価しデータを構築する。

つまり、1つの関数 $u$ に対して複数の点 $y_1, \dots, y_k$ を与えることで、以下のような学習用データペアを得る:

\left\{ \left( \boldsymbol{u}, y_j, s(y_j) \right) \right\}_{j=1}^{k}
\tag{53}

ここで:

  • $\boldsymbol{u}$:関数 $\boldsymbol{u}(x)$ を $m$ 点で観測したベクトル
  • $y_j$:出力の評価点(トランクネットへの入力)
  • $\boldsymbol{s}(y_j)$:解析的または補間的に得られる関数の出力

本クラスでは、入力関数を標準化(平均0・分散1)してからデータ構造を構築しており、学習の安定性を高めている。

dataset/deeponet_dataset.py
import numpy as np
from scipy.interpolate import interp1d
import torch
from torch.utils.data import Dataset


class DeepONetDataset(Dataset):
    def __init__(self, U, S, x, eval_pts, normalize_stats=None):
        super().__init__()
        self.data = []
        self.u_mean, self.u_std = normalize_stats or self._compute_stats(U)
        for u, s in zip(U, S):
            s_interp = interp1d(x, s, kind="cubic")
            u_norm = (u - self.u_mean) / self.u_std
            for y in eval_pts:
                self.data.append(
                    (
                        torch.tensor(u_norm, dtype=torch.float32),
                        torch.tensor([y], dtype=torch.float32),
                        torch.tensor([float(s_interp(y))], dtype=torch.float32),
                    )
                )

    def _compute_stats(self, U):
        u_all = np.stack(U)
        return u_all.mean(axis=0), u_all.std(axis=0) + 1e-6

    def get_normalization_stats(self):
        return self.u_mean, self.u_std

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

UnstackedDeepONetクラス

このクラスは、PyTorch Lightning を用いた DeepONet のシンプルな実装であり、BranchNet と TrunkNet を組み合わせて出力を生成する。
本クラスの特徴は、モデル構造が2つの独立な MLP により構成され、それらの出力の内積によって予測値を算出する点にある。ネットワーク構成は論文の Table.2 を参考にした。

モジュール構成

BranchNet

model/unstacked_deeponet.py
class BranchNet(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim), nn.Tanh(), nn.Linear(hidden_dim, out_dim)
        )

    def forward(self, u):
        return self.net(u)
  • 入力:入力関数ベクトル $\boldsymbol{u}$、in_dim 次元
  • 出力:out_dim 次元(TrunkNetと同じ次元数)
  • ネットワーク構造:1層 MLP
  • 活性化関数:$\tanh$

TrunkNet

model/unstacked_deeponet.py
class TrunkNet(nn.Module):
    def __init__(self, hidden_dim, out_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(1, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, y):
        return self.net(y)
  • 入力:評価点 $y$、1 次元
  • 出力:out_dim 次元(BranchNetと同じ次元数)
  • ネットワーク構造:3層 MLP
  • 活性化間数:$\tanh$

UnstackedDeepONet

model/unstacked_deeponet.py
class UnstackedDeepONet(pl.LightningModule):
    def __init__(self, input_dim, hidden_dim=40, out_dim=40, lr=1e-4):
        super().__init__()
        hidden_dim = max(input_dim, hidden_dim)
        self.branch = BranchNet(input_dim, hidden_dim, out_dim)
        self.trunk = TrunkNet(hidden_dim, out_dim)
        self.loss_fn = nn.MSELoss()
        self.lr = lr

    def forward(self, u, y):
        b = self.branch(u)
        t = self.trunk(y)
        return torch.sum(b * t, dim=1, keepdim=True)

    def training_step(self, batch, batch_idx):
        u, y, s = batch
        s_pred = self(u, y)
        loss = self.loss_fn(s_pred, s)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        u, y, s = batch
        s_pred = self(u, y)
        loss = self.loss_fn(s_pred, s)
        self.log("val_loss", loss, prog_bar=True, on_epoch=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)
  • BranchNet と TrunkNet の出力の内積を取ってスカラー出力
  • 損失関数:Mean Square Loss function
  • 最適化手法:Adam、学習率は lr=1e-4

TensorBoard による学習曲線の可視化

プライベートサブネット下にあるEC2の lightning_logs(これはデフォルトのログ掃き出しディレクトリ名)をローカルのブラウザで確認する。

  1. EC2内で TensorBoard を立ち上げる:

    tensorboard --logdir lightning_logs --port 6006 --bind_all
    
  2. SSHローカルポートフォワード:

    ssh -L 16006:localhost:6006 <ssh エイリアス名>
    
  3. ブラウザでアクセス:

    http://localhost:16006
    

    すると以下のようにログに書き出した結果が可視化されて見れる:

    screencapture-localhost-16006-2025-05-09-12_00_30.png

接続先のEC2で、1.学習実行、2.TensorBoard、3.ポートフォワード、の3つを実施しなければならないが、1と2は tmux を用いれば、3のssh接続ひとつで全体を操作できるのでお勧め。

センサー数 m とスケールパラメータ l による精度比較

今回の実験では、入力関数のサンプリング点数 $m$ と、ガウス過程におけるスケールパラメータ $l$ を変化させたときのDeepONetの精度を系統的に評価した。
まずはその可視化結果を下に示す:

image.png

得られた結果からは、以下のような一貫した傾向が観察された。

  • スケールパラメータ $l$ が小さいほど MSE が増加する傾向にある

    小さな $l$ に対しては高周波成分を含む入力関数 $u(x)$ が生成される。これにより、関数 $u$ に対する作用素 $\mathcal{G}(u)(\boldsymbol{y})$ も急峻 な挙動を持ち、DeepONet での推定が困難になる。

    これは論文中でも指摘されている通り、関数空間のスムーズさがネットワークの学習難易度に直結するという考察と一致している。

  • サンプリング点数 $m$ の増加により精度が向上

    $m$ を増やすことで、入力関数 $u(x)$ の離散的な表現が高解像度になる。特に小さな $l$ において、粗いサンプリングでは関数の構造を十分に捉えきれずに、誤差が大きくなることが観察される。

本実装に際して得られた学びを、備忘録として以下に整理する。

次元数に応じたデータ数のスケーリング

当初はすべての $m$ に対して同一のデータ数・バッチサイズで学習していたが、$m$ が小さいほど MSE が良化するという論文とは逆の傾向が見られた。この経験から、サンプリング点数 $m$ の増加に伴い、データ数も比例して増加させる必要性を認識した。

実際に用いたコード抜粋は下記の通り:

#  construct Gaussian Random Field generator
grf = GRF()

#  sample multipliers (per m) for each dataset
b_train, b_val, b_test = 20, 2, 20

for l, m in itertools.product(ls, ms):
    print(f"\n l={l:.3g} , m={m} ")

    #  generate GRF samples
    n_train, n_val, n_test = [b * m for b in (b_train, b_val, b_test)]
    n_funcs = n_train + n_val + n_test
    U, S, x = grf.random(m, l, n_funcs=n_funcs, seed=0)

b_train, b_val, b_test はそれぞれ訓練、検査、テストのデータ数のベース値を示しており、つまり b_train$:$b_val$:$b_test$=10:1:10$ の割合でデータセットを生成するという実装である。具体的には、$m=10$ のときには b_train=20*10、b_val=2*10、b_test=20*10 となる。

また、バッチサイズも batch_size=8*m と $m$ に応じて線形でスケールするようにした。

次元数に応じた隠れ層のノード数のスケーリング

これも上と同様で、入力次元が増えるならば、隠れ層のノード数も拡張すべきという考えから来ている。論文中では $m=10,20,40$ でそれぞれ $n\_nodes=40,40,100$ としているが、今回は $n\_nodes=4^*m$ と線形スケールするように実装した。

抜粋は以下の通り。

#  model definition & training
model = UnstackedDeepONet(m, hidden_dim=4 * m, out_dim=4 * m)
trainer = pl.Trainer(
    max_epochs=500,
    callbacks=[pl.callbacks.EarlyStopping(monitor="val_loss", patience=30)],
    accelerator="auto",
    devices=1,
    log_every_n_steps=50,
)
trainer.fit(model, train_loader, val_dataloaders=val_loader)

入力関数の生成と分割

先述した GRF クラスでは、乱数シードを指定して入力関数を生成する仕様となっているが、訓練・検証・テスト用にそれぞれ random メソッドを個別に呼び出すのは、再現性や計算効率の観点から望ましくない。一度の呼び出しで全データを生成し、スライスで分割することで、関数生成コストを抑えられる。

#  generate GRF samples
n_train, n_val, n_test = [b * m for b in (b_train, b_val, b_test)]
n_funcs = n_train + n_val + n_test
U, S, x = grf.random(m, l, n_funcs=n_funcs, seed=0)

#  build datasets with normalization
train_ds = DeepONetDataset(U[:n_train], S[:n_train], x, eval_pts)
normalize_stats = train_ds.get_normalization_stats()
val_ds = DeepONetDataset(
    U[n_train : n_train + n_val],
    S[n_train : n_train + n_val],
    x,
    eval_pts,
    normalize_stats,
)
test_ds = DeepONetDataset(
    U[n_train + n_val :], S[n_train + n_val :], x, eval_pts, normalize_stats
)

DeepONet と FNN の性能比較

2つ目の実験として、通常の FNN と DeepONet の精度比較を行う。扱う問題設定は上と同じ。ただし、$m=40$、$l=0.2$ とした。

先に、10回実行(シードは都度変更)場合の訓練・テストデータにおける平均 MSE とそのエラーバーを可視化したグラフを示す:

barplot_mlp_vs_deeponet_train_test.png

FNN は以下のネットワーク構成で実験した。

  • 隠れ層:$1〜3$(グラフ上で $h$ と表現)
  • 隠れ層のノード数:$10,\ 160,\ 2560$ (グラフ上で $w$ と表現)
  • 活性化関数:ReLU
  • 最適化手法:Adam
  • 学習率:$10^{-4}$

今回、FNN による作用素学習は入力関数ベクトル $\boldsymbol{u}=[u(x_1),\cdots,u(x_m)]\in\mathbb{R}^m$ と評価点(本問題設定ではスカラー) $y\in\mathbb{R}$ を単純に結合して、ひとつのベクトルとして入力する。
つまり、DeepONet と同様に点評価型の作用素近似として入出力を用意した。

入力構造

\text{Input vector}
=
\begin{bmatrix}
u(x_1) \\
\vdots \\
u(x_m) \\
y \\
\end{bmatrix}
\in
\mathbb{R}^{m+1}
\tag{54}

出力構造

\text{Output vector}
=
\begin{bmatrix}
\mathcal{G}(u)(y_1) \\
\vdots \\
\mathcal{G}(u)(y_k)  \\
\end{bmatrix}
\in
\mathbb{R}^{k}
\tag{55}

この実験結果を見ると、FNN は DeepONet に比べてパラメータ数が少なくても、汎化誤差が大きく、テストエラーが高い傾向にある。また、論文中でも深さを増やしてもテストエラーの改善は限定的と述べられているが、実際に隠れ層1、ノード数160の FNN が FNN の中では最良の結果となっている。

また、これは論文中では議論されていないが、訓練データ間の入力相関が高すぎるという本質的な問題が存在する。このように考える背景は以下の通り:

  • 各入力関数 $u$ は $m$ 個のサンプル点における関数値として表現され、サイズ $m$ のベクトル
  • 各 $u$ に対して、異なる評価点 $y$ に対する $\mathcal{G}(u)(y)$ を出力とするデータ点が多数生成される。
  • 故に、多くのデータ点が同じ $u$ に対してわずかに違う $y$ を加えただけの入力になる。

結果として、過学習や汎化性能の低下といった悪影響が及ぼされていると推測される。

パラメータ数と学習実行時間

以下では、パラメータ数と学習実行時間を計測した結果を整理する。

モデル (h) (w) パラメータ数
DeepONet 21.3 K
FNN 1 10 1.3 K
FNN 1 160 19.6 K
FNN 1 2560 312 K
FNN 2 10 1.4 K
FNN 2 160 45.4 K
FNN 2 2560 6.9 M
FNN 3 10 1.5 K
FNN 3 160 71.1 K
FNN 3 2560 13.4 M

image.png

DeepONet は FNN (1,160) とほぼ同程度のパラメータ数でテスト MSE がオーダーひとつ小さい。FNN を DeepONet と同程度の精度まで追い付かせようとパラメータ数を強引に増やしても計算コストが急増し、実用面でのパフォーマンスが悪化する上に、汎化性能も改善されない。

おわりに

作用素学習は、関数空間間の写像を学習するという機械学習の新興分野であり、PDEの解法や複雑な物理現象のモデリングにおいて重要な役割を果たしつつある。今回はこの中で DeepONet を紹介したが、他にも Fourier/Wavelet Neural Operator や Graph (Kernel) Neural Operator 系の手法なども存在するので、こちらも調査していきたい。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?