概要
scipy.optimize.curve_fitは、SciPyライブラリの最適化モジュールに含まれる関数で、データに最もよくフィットするパラメータを非線形な関数から推定するために使うことができます。scipyのカーブフィットで正弦波の振幅と位相を求めてみます。
他の手法との比較
正弦波の振幅と位相をScipyのカーブフィットで求める
https://qiita.com/nnn112358/items/35cfee57bd1d8f21147b
正弦波の振幅と位相をSciPyのFFTで求める
https://qiita.com/nnn112358/items/dfded784541045f988d4
正弦波の振幅と位相をLSTMで求める
https://qiita.com/nnn112358/items/1c9b077a7a413ee1dd9f
正弦波の振幅と位相をRNNで求める
https://qiita.com/nnn112358/items/63b53fb192b7f10d509a
curve_fitの説明
scipy.optimize.curve_fit(f, xdata, ydata, p0=None, sigma=None, absolute_sigma=False, bounds=(-inf, inf), method=None, jac=None, **kwargs)
f: フィッティング関数。独立変数(通常x)とフィットさせるパラメータ(a, b, cなど)を引数にとる関数。
xdata: 独立変数のデータ。配列またはリスト形式。
ydata: 従属変数のデータ。配列またはリスト形式。
p0: フィットの初期推定値。フィットさせるパラメータの初期値のリストまたはタプル。省略すると、すべてのパラメータが1に設定される。
sigma: yデータの標準偏差。これにより、重み付きフィッティングが可能になる。yデータと同じ形状の配列。
absolute_sigma: Trueの場合、sigmaを絶対的な値として使用し、Falseの場合、相対的な値として使用する。
bounds: パラメータの下限と上限を指定するタプル。デフォルトは(-inf, inf)で、制限なし。
method: 最適化方法を指定する文字列。デフォルトは'lm'(レーベンバーグ・マルカート法)だが、他に'Trust Region Reflective'や'Dogbox'などがある。
jac: ヤコビ行列(フィッティング関数の導関数)。特定の最適化方法でのみ使用される。
kwargs: 他のオプションパラメータ。
返り値
popt: フィットしたパラメータの最適値を含む配列。
pcov: パラメータの共分散行列。対角要素はパラメータ推定の分散を表す。
注意点
初期推定値: フィットがうまく収束するためには、適切な初期推定値が大切です。データのスケール: データのスケールが非常に大きい場合や小さい場合、正規化することでフィットが改善することがあります。 methodパラメータを変更することで、異なる最適化アルゴリズムを試すことができます。
Python開発環境のインストール
Python環境として、Minicondaをインストールする手順を説明します。
WSLにUbuntu22.04 をインストールして使用しました。
MinicondaのLinuxへのインストール手順
-
Minicondaのインストーラをダウンロード:
まず、Minicondaの公式ウェブサイトからインストーラをダウンロードします。以下のコマンドで最新版のMinicondaインストーラをダウンロードします。wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
-
インストーラの実行:
ダウンロードしたスクリプトを実行してMinicondaをインストールします。インストール先を指定する場合は、オプションでディレクトリを指定できます。bash Miniconda3-latest-Linux-x86_64.sh
インストールプロセスが開始されると、いくつかのプロンプトが表示されます。これに従って進めてください。
- 使用許諾契約書が表示されたら、スペースキーを押して全文を表示し、
yes
と入力して同意します。 - インストールディレクトリの確認が表示されるので、デフォルトのままで良ければそのままエンターキーを押します。カスタムディレクトリにインストールしたい場合は、パスを入力してエンターキーを押します。
- インストールが完了したら、
conda
コマンドを使えるようにするために、環境変数の設定を行うかどうかのプロンプトが表示されます。これもyes
と入力して進めます。
- 使用許諾契約書が表示されたら、スペースキーを押して全文を表示し、
-
環境変数の設定:
インストールが完了したら、Minicondaのバイナリディレクトリを環境変数PATH
に追加します。多くの場合、インストールスクリプトが自動的にこれを行いますが、手動で行う場合は以下のコマンドを実行します。export PATH=~/miniconda3/bin:$PATH
この設定を永続化するために、
~/.bashrc
(または使用しているシェルの設定ファイル)に追加します。echo 'export PATH=~/miniconda3/bin:$PATH' >> ~/.bashrc source ~/.bashrc
-
インストールの確認:
conda
コマンドを実行して、Minicondaのインストールが正しく行われたことを確認します。conda --version
これでMinicondaのバージョンが表示されれば、インストールは成功です。
-
Condaの初期設定とアップデート:
インストール後、初期設定を行い、Conda自体とデフォルトパッケージをアップデートします。conda init conda update conda
これでMinicondaのインストールは完了です。
仮想環境の作成
Minicondaのインストールが完了したら、仮想環境を作成します。
仮想環境の作成:
conda create -n new_env python=3.8 matplotlib scikit-learn pykalman
仮想環境のアクティベート:
conda activate new_env
仮想環境のディアクティベート:
仮想環境を終了するには、以下のコマンドを実行します。
conda deactivate
サンプルソース
正弦波データにノイズを加え、そのデータに対してscipy.optimize.curve_fit関数を使用して最適な正弦波のパラメータを推定し、結果をグラフに表示してみます。
$ python curve_fit.py
Original amplitude: 2.0000
Fitted amplitude: 2.0329
Original phase: 0.7854
Fitted phase: 0.7633
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
# 正弦波関数の定義
def sine_wave(t, A, phi):
return A * np.sin(2 * np.pi * known_freq * t + phi)
# パラメータ設定
N = 1000 # サンプル数
T = 0.1 # サンプリング間隔
t = np.linspace(0.0, N*T, N) # 時間軸の生成
# 既知の周期と周波数
known_period = 5.0 # 既知の周期
known_freq = 1.0 / known_period # 周波数の計算
# 元の正弦波を生成
amplitude_original = 2.0 # 元の正弦波の振幅
phase_original = np.pi / 4 # 元の正弦波の位相
y_original = amplitude_original * np.sin(2.0 * np.pi * known_freq * t + phase_original) # 元の正弦波データ
# ノイズを加える
np.random.seed(42) # 再現性のためにランダムシードを設定
noise = np.random.normal(0, 0.5, N) # 正規分布に従うノイズを生成
y_noisy = y_original + noise # ノイズを加えたデータ
# CurveFitを使用してパラメータを推定
popt, _ = curve_fit(sine_wave, t, y_noisy, p0=[1.0, 0.0]) # カーブフィットを実行
amplitude_fit, phase_fit = popt # 最適な振幅と位相を取得
# フィッティングした正弦波を生成
y_fit = sine_wave(t, amplitude_fit, phase_fit) # フィッティング結果を用いた正弦波データ
# グラフ描画
plt.figure(figsize=(12, 8)) # グラフのサイズを設定
# ノイズデータ、元の正弦波、フィッティングした正弦波をプロット
plt.plot(t, y_noisy, label='Noisy Data', alpha=0.5) # ノイズデータ
plt.plot(t, y_original, label='Original Sine Wave', linewidth=2) # 元の正弦波
plt.plot(t, y_fit, label='Fitted Sine Wave', linestyle='--') # フィッティングした正弦波
plt.xlabel('Time') # X軸ラベル
plt.ylabel('Amplitude') # Y軸ラベル
plt.title('Sine Wave Fitting with CurveFit') # グラフタイトル
plt.legend() # 凡例を表示
plt.tight_layout() # レイアウトを自動調整
plt.show() # グラフを表示
# 推定されたパラメータと元のパラメータを表示
print(f"Original amplitude: {amplitude_original:.4f}")
print(f"Fitted amplitude: {amplitude_fit:.4f}")
print(f"Original phase: {phase_original:.4f}")
print(f"Fitted phase: {phase_fit:.4f}")
参考文献
SciPy API>Optimization and root finding (scipy.optimize)>curve_fit
https://docs.scipy.org/doc/scipy-1.14.0/reference/generated/scipy.optimize.curve_fit.html