17
20

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

scipy.optimize.least_squaresではうまく最適化できない例とその対処法

Last updated at Posted at 2016-11-07

scipyではoptimize.least_squaresを使うことで、非線形関数のパラメーターをデータにフィットさせることができます。しかし、非線形関数の形によっては、最適なパラメーターを求めることができないことがあります。なぜなら、optimize.least_squaresでは局所的な最適解しか求めることができないからです。

今回は、optimize.least_squaresが局所最適解に陥ってしまう例を挙げ、optimize.basinhoppingを使って大域的最適解を求めてみます。

バージョン:

  • Python 3.5.1
  • numpy (1.11.1)
  • scipy (0.18.0)

うまく最適化できない例

$a$をパラメーターとしてもつ次のような関数を考えます。

$y(x)=\frac{1}{100}(x-3a)(2x-a)(3x+a)(x+2a)$

$a=2$のときにノイズがのったデータが得られたとします。


import numpy as np
import matplotlib.pyplot as plt
plt.style.use('ggplot')

seed = 0
np.random.seed(seed)

def y(x, a):
    return (x-3.*a) * (2.*x-a) * (3.*x+a) * (x+2.*a) / 100.

a_orig = 2.
xs = np.linspace(-5, 7, 1000)
ys = y(xs,a_orig)

num_data = 30
data_x = np.random.uniform(-5, 5, num_data)
data_y = y(data_x, a_orig) + np.random.normal(0, 0.5, num_data)

plt.plot(xs, ys, label='true a = %.2f'%(a_orig))
plt.plot(data_x, data_y, 'o', label='data')
plt.legend()

qiita_1.png

これに対して、optimize.least_squaresでパラメーターを求めてみます。

from scipy.optimize import least_squares

def calc_residuals(params, data_x, data_y):
    model_y = y(data_x, params[0])
    return model_y - data_y

a_init = -3
res = least_squares(calc_residuals, np.array([a_init]), args=(data_x, data_y))

a_fit = res.x[0]
ys_fit = y(xs,a_fit)

plt.plot(xs, ys, label='true a = %.2f'%(a_orig))
plt.plot(xs, ys_fit, label='fit a = %.2f'%(a_fit))
plt.plot(data_x, data_y, 'o')
plt.legend()

qiita_2.png

パラメーターの初期値を$a_0 = -3$にしたところ、うまくデータにフィットしませんでした。

うまく最適化できない理由

パラメーターの初期値によってどのように結果が変わるかを調べてみると、

a_inits = np.linspace(-4, 4, 1000)
a_fits = np.zeros(1000)
for i, a_init in enumerate(a_inits):
    res = least_squares(calc_residuals, np.array([a_init]), args=(data_x, data_y))
    a_fits[i] = res.x[0]

plt.plot(a_inits, a_fits)
plt.xlabel("initial value")
plt.ylabel("optimized value")

qiita_3.png

初期値が負だと、局所的に最適なパラメーターに陥ってしまっています。この理由は、パラメーターの値と残差の関係を見てみると分かります。下図のように、パラメーターに対して極小値がふたつあるせいで、初期値に依存して結果が変わってしまうのです。

def calc_cost(params, data_x, data_y):
    residuals = calc_residuals(params, data_x, data_y)
    return (residuals * residuals).sum()

costs = np.zeros(1000)
for i, a in enumerate(a_inits):
    costs[i] = calc_cost(np.array([a]), data_x, data_y)
plt.plot(a_inits, costs)
plt.xlabel("parameter")
plt.ylabel("sum of squares")

qiita_4.png

うまく最適化する方法

大域的に最適なパラメーターを求めるには、色々な初期値から計算してみればよい、ということになります。これをいい感じにやる方法としてscipyには、optimize.basinhoppingがあります。それではやってみましょう。

from scipy.optimize import basinhopping
a_init = -3.0
minimizer_kwargs = {"args":(data_x, data_y)}
res = basinhopping(calc_cost, np.array([a_init]),stepsize=2.,minimizer_kwargs=minimizer_kwargs)
print(res.x)

a_fit = res.x[0]
ys_fit = y(xs,a_fit)

plt.plot(xs, ys, label='true a = %.2f'%(a_orig))
plt.plot(xs, ys_fit, label='fit by basin-hopping a = %.2f'%(a_fit))
plt.plot(data_x, data_y, 'o')
plt.legend()

qiita_5.png

うまくパラメーターが求まりました。コツは引数のstepsizeです。この引数がどれくらい初期値を大きく変えるかを決めます。

17
20
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
17
20

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?