LoginSignup
23
33

More than 3 years have passed since last update.

機械学習モデルを逆解析する

Last updated at Posted at 2020-05-17

逆解析とは

  • 広義には,出力から入力を推定したり,何らかの方程式の解を求めること.
  • 狭義には,材料設計や化学などの分野において,欲しい物性を先に定めて,それを実現する素材の条件を求める解析のこと.

一般に,合成元の条件から,合成された物質の物性を求めることを順問題ということから,その逆方向ということで逆問題を解くとか表現することもあります.

2020-05-17_21h41_33.png

この記事の目的

機械学習モデルの逆解析を行うこと.

機械学習によって物性が予測できるようになったなら,そのモデルの出力が所定の値になるような入力値を探索することができるはずです.

しかし,入力次元(説明変数の数)が多いほど探索すべき空間は膨大になり,探索時間や計算機の性能に応じて逆解析ができないケースが発生するはずです.

  1. そこで,まずは単純なデータに対して予測モデルを作成し,逆解析が可能であることを確認します.
  2. その後,説明変数の数を増加させながら,逆解析の精度がどのように落ちていくかを調査します.(後日追記予定)

基本設計

2020-05-17_22h26_01.png

複雑なことはせずに,回帰モデルに対して出力が最小になるような入力値を探索することを試みます.

回帰モデルとしては差し当たりランダムフォレストを.

探索アルゴリズムとしてはSMBO(Sequential Model-based Global Optimization)を使用し,そのライブラリとしてhyperoptを使用します.(他にも様々な手法があります).

1. トイモデルに対する逆解析

環境

  • Python 3.6.10
  • scikit-learn 0.22.0
  • hyperopt 0.2.4

設定

簡単なモデルとして,

y= x_1 {}^2 + x_2 {}^2, \qquad (x_1, x_2) \in \mathbb{R} ^2

なる対応を考えます.明らかに最小値は$0$で,これを与える点は$(x_1, x_2) = (0,0)$です.

2020-05-17_22h37_02.png

上のグラフ生成のためのコード
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D

def true_function(x, y):
    """真の関数"""
    return x ** 2 + y ** 2

X, Y = np.mgrid[-100:100, -100:100]
Z = true_function(X, Y)

plt.rcParams["font.size"] = 10  # フォントサイズを大きくする
fig = plt.figure(figsize = (12, 9))
ax = fig.add_subplot(111, projection="3d", facecolor="w")
ax.plot_surface(X, Y, Z, cmap="rainbow", rstride=3, cstride=3)
ax.set_xlabel('x1', fontsize=15)
ax.set_ylabel('x2', fontsize=15)  
ax.set_zlabel('y', fontsize=15) 
plt.show()

1.1 トレーニング・テストデータの生成

上記の対応に基づき,入出力のサンプル群を生成します.

生成したトレーニングデータを描画します.

2020-05-17_22h48_51.png

上のグラフ生成のためのコード
from sklearn.model_selection import train_test_split

def true_model(X):
    return true_function(X[:,[0]], X[:,[1]])

X = np.random.uniform(low=-100,high=100,size=(3000,2))
Y = true_model(X)

test_size = 0.3  # 分割比率

x_train, x_test, y_train, y_test = train_test_split(X, Y, test_size=test_size, random_state=0)

fig = plt.figure(figsize = (12, 9)) 
ax = plt.axes(projection ="3d")
sctt = ax.scatter3D(x_train[:,0], x_train[:,1], y_train[:,0], c=y_train[:,0], s=8, alpha = 0.6,
                    cmap = plt.get_cmap('rainbow'), marker ='^') 

plt.title("x_train, y_train") 
ax.set_xlabel('x1', fontsize=15)
ax.set_ylabel('x2', fontsize=15)  
ax.set_zlabel('y', fontsize=15) 
plt.show() 

1.2 学習・推論

先ほどのトレーニングデータでランダムフォレストを学習させ,テストデータに対して推論させた結果を描画します.

2020-05-17_22h52_46.png

概ね正しい値を推定できているようです.

1.3 回帰モデルに対する最小値の探索

hyperoptによる最小値の探索を試みます.

最小化対象となる関数を定義したのち,最小値を与える点を探索し,得られた点を先ほどの図に重ねて描画します.

2020-05-17_23h14_09.png

上のグラフ生成のためのコード
from hyperopt import hp
from hyperopt import fmin
from hyperopt import tpe

def objective_hyperopt_by_reg(args):
    """hyperopt用の目的関数"""
    global reg
    x, y = args
    return float(reg.predict([[x,y]]))

def hyperopt_exe():
    """hyperoptによる最適化の実行"""
    # 探索空間の設定
    space = [
        hp.uniform('x', -100, 100),
        hp.uniform('y', -100, 100)
    ]

    # 探索開始
    best = fmin(objective_hyperopt_by_reg, space, algo=tpe.suggest, max_evals=500)
    return best

best = hyperopt_exe()
print(f"best: {best}")

fig = plt.figure(figsize = (12, 9)) 
ax = plt.axes(projection ="3d")
sctt = ax.scatter3D(x_test[:,0], x_test[:,1], y_test[:,0], c=y_test[:,0], s=6, alpha = 0.5,
                    cmap = plt.get_cmap('rainbow'), marker ='^')
ax.scatter3D([best["x"]], [best["y"]], [objective_hyperopt_by_reg((best["x"], best["y"]))], 
                    c="red", s=250, marker="*", label="minimum") 

plt.title("x_test, y_pred", fontsize=18) 
ax.set_xlabel('x1', fontsize=15)
ax.set_ylabel('x2', fontsize=15)  
ax.set_zlabel('y', fontsize=15) 
plt.legend(fontsize=15)
plt.show() 

output
100%|██████████████████████████████████████████████| 500/500 [00:09<00:00, 52.54trial/s, best loss: 27.169204190118908]
best: {'x': -0.6924078319870626, 'y': -1.1731945130395605}

最小点に近い点が得られました.

まとめと課題

この記事では,回帰モデルの学習 ⇒ 逆解析 の手続きを単純なデータに対して実行しました.

今回は入力次元が小さかったため,偶然にもうまく最小値を探索できましたが,実際のデータに適用にするには様々な課題が生じると予想されます.

  • データが少なすぎて学習が足りない
  • 不適切な回帰モデルを選択してしまったために現実にはあり得ない解を導いてしまう.
  • 説明変数の次元が高すぎてうまく探索できない・あるいは極所解にトラップされる.
  • データの分布に偏りがあり,疎な領域における学習が不十分.
  • 過大なノイズを学習してしまう.

また,逆問題は適切に扱わないと極めてナンセンスな解析をしてしまう恐れがあると予想されます.

  • そもそも解が存在しない
  • 解が唯一ではない
  • 真の分布が不連続であり,解が安定でない

このような問題があるにもかかわらず逆問題を設定してしまったが故に,無駄な労力を費やすことは絶対に避けたいところです.

参考

23
33
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
23
33