LoginSignup
1
1

More than 3 years have passed since last update.

バネのダイナミクス予測

Posted at

こちらの文献を参考に、下図の状況のようなバネにつながれた物体のダイナミクスのNNによる学習・予測を試してみます。

image.png

具体的には時刻$t_n$までの各物体の位置$\mathbf{r(t_n)}$、速度$\mathbf{\dot{\mathbf{r}}(t_n)}$、加速度$\mathbf{\ddot{\mathbf{r}}(t_n)}$を入力として、一期先時刻$t_{n+1}$での加速度変化($\mathbf{\ddot{\mathbf{r}}(t_{n+1})}$-$\mathbf{\ddot{\mathbf{r}}(t_n)}$)を予測します。
※文献の詳細は読めていないので、思想的なところだけ参考に試していきます。

大まかな流れ

  1. 運動方程式から数値計算を行いていくつかの初期条件で教師データを生成
  2. 1で作成した教師データを用いてNNモデルを学習
  3. 適当な状態からその先の状態をNNモデルから逐次的に予測し、それらしい運動になるか確認する。

1. 教師データの生成

簡易のため、適当な初期状態を生成後、運動方程式からEuler法により逐次計算を行います。

  • 位置$\mathbf{r}_{i, j}$の物体の運動方程式

image.png

上下左右4方向からの力の和で、$m$は質量、$\mathbf{r}^{eq}_{i,j,i\prime,j\prime}$は$(i,j)$の物体から$(i\prime, j\prime)$の物体を向いた平衡位置のベクトル、$k$はばね定数です。端っこの連結がない部分は除外します。

上記式から、以下Euler法で一期先の状態を計算します

image.png

  • 物体定義と初期化
@dataclass
class Obj:
    idx: int
    pos: np.array
    vel: np.array
    accel: np.array
    force: np.array
    connected_objs: List = None
    mass: float = 1.0
    energy: float = 0.0

def init_objs_2d():
    objs = {}
    # 物体作成
    for i in range(4):
        for j in range(4):
            objs[(i,j)] = Obj(
                idx=i*4+j,
                pos=np.array([i, j])+0.05*np.random.randn(2),
                vel=0.01*np.random.randn(2),
                accel=np.zeros(2),
                force=np.zeros(2),
            )

    # バネ連結情報付与
    for i in range(4):
        for j in range(4):
            connected_objs = []
            if i-1 >= 0:
                connected_objs.append(objs[(i-1,j)])
            if i+1 <= 3:
                connected_objs.append(objs[(i+1,j)])
            if j-1 >= 0:
                connected_objs.append(objs[(i,j-1)])
            if j+1 <= 3:
                connected_objs.append(objs[(i,j+1)])

            objs[(i,j)].connected_objs = connected_objs
    return objs
  • 各種状態更新関数
def force(pos1: np.array, pos2: np.array):
    """pos2の物体からpos1の物体に働くフックの法則による力"""
    dist = np.sqrt(((pos1-pos2)**2).sum())
    r_normed = (pos2-pos1)/dist
    equilibrium_vec = EQUILIBRIUM_LENGTH*r_normed
    f = -K*(equilibrium_vec-(pos2-pos1))
    return f

def _update_force_on_obj(obj: Obj):
    """物体に働く力の更新"""
    forces = [force(obj.pos, obj2.pos) for obj2 in obj.connected_objs]
    obj.force = sum(forces)

def _update_accel(obj: Obj):
    """物体の加速度更新"""
    obj.accel = obj.force / obj.mass

def _update_velocity(obj: Obj):
    """物体の速度更新"""
    obj.vel = obj.vel + DT*obj.accel

def _update_pos(obj: Obj):
    """物体の位置更新"""
    obj.pos = obj.pos + DT*obj.vel

def update_states(objs: List[Obj]):
    """全物体の状態更新"""
    [_update_force_on_obj(obj) for obj in objs.values()]
    [_update_accel(obj) for obj in objs.values()]
    [_update_velocity(obj) for obj in objs.values()]
    [_update_pos(obj) for obj in objs.values()]
    [_calc_energy(obj) for obj in objs.values()]
  • 計算結果の可視化

spring_dynamics_euler_calculated2.gif

赤い矢印が物体に働いているフックの法則による力のベクトルで、緑の矢印が物体の速度ベクトルです。

全力学的エネルギーの時間変化も確認してみます。

image.png

保存されていないみたいですが、気にせず進みます。

2. ニューラルネットワークモデルの学習

1で生成したデータから、各物体の時刻$t_n$までの各物体の状態を入力、各物体の一期先$t_{n+1}$の加速度変化を出力としたグラフニューラルネットワークを学習します。

モデル構造

モデル全体像イメージは冒頭で述べた文献記載の下図モデルを大まかに参考にし、GNがグラフニューラルネットワーク(GNN)モジュール、エンコーダー部分の詳細は分からなかったのでエンコーダーは除外、デコーダーは単純に全結合層にしました。
グラフにおいて、物体がノード、物体をつなげてるバネがエッジで、ノード(物体)に特徴量(位置、速度等)が紐づいているイメージです。

image.png

ref

  • GNNモジュール
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GNNModule(torch.nn.Module):
    def __init__(self, n_features, out_dim, intermediate_dim=24):
        super(GNNModule, self).__init__()
        self.conv1 = GCNConv(n_features, intermediate_dim)
        self.conv2 = GCNConv(intermediate_dim, out_dim)

    def forward(self, x, edge_index):
        x2 = self.conv1(x, edge_index)
        x3 = F.relu(x2)
        x4 = F.dropout(x3, training=self.training)
        x5 = self.conv2(x4, edge_index)
        return x5

(バッチサイズ, ノード数(物体数), n_features)のfeatureと、(2, エッジ数(バネの数))のedge_indexを入力すると、(バッチサイズ、out_dim)のテンソルが出力されます。

  • 全体ネットワーク
class SpringDynamicsNet(torch.nn.Module):
    def __init__(self, n_features, out_dim, n_gnn_module=4, intermediate_dim=24):
        super(SpringDynamicsNet, self).__init__()
        self._intermediate_gnn_modules = nn.Sequential(
            *[GNNModule(n_features, n_features) for _ in range(n_gnn_module)]
        )
        self._fc = nn.Linear(n_features, out_dim)

    def forward(self, x, edge_index):
        for sping_module in self._intermediate_gnn_modules:
            residual = x
            gnn_out = sping_module(x, edge_index)
            x = residual + gnn_out
        return self._fc(x)

(バッチサイズ, ノード数(物体数), n_features)のfeatureと、(2, エッジ数(バネの数))のedge_indexを入力すると、(バッチサイズ、out_dim)のテンソルが出力されます。

入出力データ

入力の状態は時刻$t_n$における各物体の位置・加速度、$t_n$から$N$期前までの各物体の速度とし、今回は$N$は3で行い、出力は1期先の加速度変化としました。
具体的に書くと、物体1つに対し、入力は
$[x(t), y(y), v_x(t), v_y(t), v_x(t-1), v_y(t-1), v_x(t-2), v_y(t-2), v_x(t-3), v_y(t-3), a_x(t), a_y(t)]$
出力は
$[a_x(t+1)-a_x(t), a_y(t+1)-a_y(t)]$
といった感じです。
これが物体数(16)あるので、入力が(バッチサイズ、16, 12)、出力が(バッチサイズ、16, 2)のテンソルになります。
モデル入力には上記に加え、下記の物体連結情報のエッジデータも渡します。

エッジデータ作成

連結している物体のインデックスペアリストを作ります。

edge_index = {}
for obj in objs.values():
    for obj2 in obj.connected_objs:
        edge_index[(obj.idx, obj2.idx)] = 1

edge_index = list(edge_index.keys())
edge_index = torch.tensor(edge_index).T

学習

spring_dynamics_net = SpringDynamicsNet(
    n_features=FEATURE_DIM,
    out_dim=2,
    n_gnn_module=4,
    intermediate_dim=24
)

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(spring_dynamics_net.parameters(), lr=0.01, weight_decay=5e-4)

BATCH_SIZE = 32
N_EPOCHS = 20

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
spring_dynamics_net.to(device)
spring_dynamics_net.train()

loss_histroy = []

for epoch in pb(range(N_EPOCHS)):
    epoch_loss = 0
    for i in range(feats.shape[0]//BATCH_SIZE):
        feat_batch = feats[i*BATCH_SIZE:(i+1)*BATCH_SIZE, :, :]
        out_batch = outs[i*BATCH_SIZE:(i+1)*BATCH_SIZE, :, :]
        optimizer.zero_grad()
        preds = spring_dynamics_net(feat_batch, edge_index) 
        loss = criterion(out_batch, preds)
        loss.backward(retain_graph=True)
        optimizer.step()
        epoch_loss += loss.item()*feat_batch.shape[0]
    loss_histroy.append(epoch_loss)
  • 損失時系列

image.png

学習できてるようです。

3. 適当な状態からその先の状態をNNモデルから逐次的に予測

全体コード

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1