1
0

Phy-SOによるシンボリック回帰 on colaboratory

Posted at

Phy-SO

(以下arxiv abstractの日本語訳)
記号回帰は、データに適合する解析式の探索を自動化するアルゴリズムの研究である。最近の深層学習の進歩により、このようなアプローチに再び関心が集まっているが、記号回帰手法の開発は、データに関連するユニットのために重要な追加制約がある物理学には焦点が当てられていない。ここでは、物理データから解析的な記号式を復元するための物理記号最適化フレームワークであるΦ-SOを、ユニットの制約を学習することによって、深層強化学習技術を用いて紹介する。我々のシステムは、物理単位が構造的に一貫している解を提案するように、一から構築されている。これは物理的に不可能な解を排除するだけでなく、次元解析の "文法的 "規則が方程式生成器の自由度を大幅に制限するため、パフォーマンスが大幅に向上する。このアルゴリズムは、ノイズのないデータのフィッティングに使用することができ、例えば物理モデルの解析的特性を導こうとする場合に有用であり、またノイズのあるデータの解析的近似を得るためにも使用できる。我々は、ファインマン物理学講義やその他の物理学の教科書に掲載されている方程式の標準的なベンチマークで我々の機械をテストし、ノイズ(0.1%を超える)の存在下で最先端の性能を達成し、かなりの(10%)ノイズの存在下でもロバストであることを示す。宇宙物理学の例でその能力を紹介する。

image.png
image.png

doumentation:https://physo.readthedocs.io/en/latest/index.html
paper:https://physo.readthedocs.io/en/latest/index.html
github:https://github.com/WassimTenachi/PhySO

ノイズ存在下でもcurve fittingを関数から行えるという理解をしている。githubで公開されているので、本手法をcolaboratory上で試してみた。

On Colaboratory

githubではいくつかデモのnotebookが公開されている。
colaboratory上で動かすにはmatplotlibのlatexでエラーを吐いたので、改善法を記載する。
テスト例を見たい方は↓のリンク参照
Open In Colab

ただし、不機嫌なときはエラーを吐く←?

まずgithubからphy-SOをcloneし、PATHを追加する。

!git clone https://github.com/WassimTenachi/PhySO

import sys
sys.path.append("/content/PhySO")

そしてlatexエラーを回避するため、必要?なものをインストールする。

!apt install texlive-fonts-recommended texlive-fonts-extra cm-super dvipng

matplotlibのlatexを設定する。

import matplotlib
from matplotlib import rc, rcParams
import matplotlib.pyplot as plt
rc('text', usetex=True)

これで準備完了!

使い方

まだチュートリアルをこなしただけだが、わかる範囲で記載する。

とりあえずダミーデータを作成
ダミーデータは

E = mgz + mv^2
m=1.234, g=9,807

とした。

import numpy as np
import pandas as pd
import plotly.express as px


z = np.random.uniform(-10, 10, 100)
v = np.random.uniform(-10, 10, 100)
X = np.stack((z, v), axis=0)
y = 1.234*9.807*z + 1.234*v**2

df = pd.DataFrame()
df["z"]=z
df["v"]=v
df["E"]=y

px.scatter_3d(data_frame=df,x="z",y="v",z="E")

newplot.png

Phy-SOは様々な設定があるが、チュートリアルをそのまま使用する。

import physo

expression, logs = physo.SR(X, y,
                            X_names = [ "z"       , "v"        ],
                            X_units = [ [1, 0, 0] , [1, -1, 0] ],
                            y_name  = "E",
                            y_units = [2, -2, 1],
                            fixed_consts       = [ 1.      ],
                            fixed_consts_units = [ [0,0,0] ],
                            free_consts_names = [ "m"       , "g"        ],
                            free_consts_units = [ [0, 0, 1] , [1, -2, 0] ],
                            op_names = ["mul", "add", "sub", "div", "inv", "n2", "sqrt", "neg", "exp", "log", "sin", "cos"]
)

これで学習を開始する。学習中はエポック毎に結果が表示される。

SR task started...
=========== Epoch 00000 ===========
-> Time 14.29 s

Overall best  at R=0.535048
-> Raw expression : 
                                 2
  ⎛                1            ⎞ 
m⋅⎜v + ─────────────────────────⎟ 
  ⎜                          1.0⎟ 
  ⎜    -log(-1.0 + sin(1.0))⋅───⎟ 
  ⎝                           v ⎠ 

Best of epoch at R=0.535048
-> Raw expression : 
                                 2
  ⎛                1            ⎞ 
m⋅⎜v + ─────────────────────────⎟ 
  ⎜                          1.0⎟ 
  ⎜    -log(-1.0 + sin(1.0))⋅───⎟ 
  ⎝                           v ⎠ 


=========== Epoch 00001 ===========
-> Time 11.77 s

Overall best  at R=0.605056
-> Raw expression : 
         -g          
─────────────────────
⎛          ⎛   1   ⎞⎞
⎜-1.0 - sin⎜───────⎟⎟
⎜          ⎜    0.5⎟⎟
⎜          ⎝-2.0   ⎠⎟
⎜───────────────────⎟
⎝        m⋅z        ⎠

Best of epoch at R=0.605056
-> Raw expression : 
         -g          
─────────────────────
⎛          ⎛   1   ⎞⎞
⎜-1.0 - sin⎜───────⎟⎟
⎜          ⎜    0.5⎟⎟
⎜          ⎝-2.0   ⎠⎟
⎜───────────────────⎟
⎝        m⋅z        ⎠


=========== Epoch 00002 ===========
.
.
.

CPUで動かしているが非常に高速。
optimizeが完了すると以下のグラフが表示される。
image.png
image.png

結果は以下で表示できる。
関数の複雑さを考慮しており、必要な条件に応じて選択できそう

pareto_front_complexities, pareto_front_programs, pareto_front_r, pareto_front_rmse = logs.get_pareto_front()

for prog in pareto_front_programs:
    prog.show_infix(do_simplify=True)
    free_consts = prog.free_const_values.detach().cpu().numpy()
    for i in range (len(free_consts)):
        print("%s = %f"%(prog.library.free_const_names[i], free_consts[i]))

image.png

デモのnotebookに載っていた関数を使用するとRMSEとcomplexityを示せて非常にわかりやすい。
今回では一番下のfitting curveがRMSEが0で実数部を除去すると元の関数と一致することがわかる。

def plot_pareto_front(run_logger,
                      do_simplify                   = True,
                      show_superparent_at_beginning = True,
                      eq_text_size                  = 12,
                      delta_xlim                    = [0, 5 ],
                      delta_ylim                    = [0, 15],
                      frac_delta_equ                = [0.03, 0.03],
                      figsize                       = (20, 10),
                     ):

    pareto_front_complexities, pareto_front_programs, pareto_front_r, pareto_front_rmse = run_logger.get_pareto_front()

    pareto_front_rmse = pareto_front_rmse
    # Fig params
    plt.rc('text', usetex=True)
    plt.rc('font', family='serif')
    # enables new_dummy_symbol = "\square"
    plt.rc('text.latex', preamble=r'\usepackage{amssymb} \usepackage{xcolor}')
    plt.rc('font', size=32)

    # Fig
    fig, ax = plt.subplots(1, 1, figsize=figsize)
    ax.plot(pareto_front_complexities, pareto_front_rmse, 'r-')
    ax.plot(pareto_front_complexities, pareto_front_rmse, 'ro')

    # Limits
    xmin = pareto_front_complexities.min() + delta_xlim[0]
    xmax = pareto_front_complexities.max() + delta_xlim[1]
    ymin = pareto_front_rmse.min() + delta_ylim[0]
    ymax = pareto_front_rmse.max() + delta_ylim[1]
    ax.set_xlim(xmin, xmax)
    ax.set_ylim(ymin, ymax)

    # Axes labels
    ax.set_xlabel("Expression complexity")
    ax.set_ylabel("RMSE")


    for i_prog in range (len(pareto_front_programs)):
        prog = pareto_front_programs[i_prog]

        text_pos  = [pareto_front_complexities[i_prog] + frac_delta_equ[0]*(xmax-xmin),
                     pareto_front_rmse[i_prog]         + frac_delta_equ[1]*(ymax-ymin)]
        # Getting latex expr
        latex_str = prog.get_infix_latex(do_simplify = do_simplify)
        # Adding "superparent =" before program to make it pretty
        if show_superparent_at_beginning:
            latex_str = prog.library.superparent.name + ' =' + latex_str


        ax.text(text_pos[0], text_pos[1], f'${latex_str}$', size = eq_text_size)


plot_pareto_front(logs)

image.png

"" 所感
非常に有用そうなである。手近なデータで実際に使用してみたい。

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0