9
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

Rustその2Advent Calendar 2019

Day 23

RustでGauss-Newton法

Last updated at Posted at 2019-12-22

(一回間違って投稿してしまったので,一度消して修正した上で最投稿しました.既に見てくださった方には申し訳ありません.)

何か作るつもりだったのですが,他にもやりたいことが発生して時間が取れなかったのと,実装力・設計力不足により,作ることができなかったのでGauss-Newton法をやることでお茶を濁します.

非線形最小二乗法

そもそもGauss-Newton法とは非線形最小二乗法を解く手法の一つですが,そもそも非線形最小二乗法とは何かについて簡単に記述します.

最小二乗法は,例の『ディープラーニングは「最小二乗法」』1でお馴染みだと思いますが,説明変数 $z$,被説明変数 $y$ に対してモデル関数 $f$ とパラメータ $x$ 与えたときの,以下の式で表されるような,最小二乗誤差 $E$ を最小化するようにパラメータ $x$ を決めるような方法のことです.

\begin{aligned}
E &=\frac{1}{2}||e(x)||^2 \\
e(x) &=y-f(z;x)
\end{aligned}

その中でも,$f$ が非線形関数であるような場合を特に,非線形最小二乗法というわけです.

Gauss-Newton法

非線形最小二乗法を解く手法の一つです.
そもそも,$E$ の勾配 $g$ と Hesse行列 $H$ から,$x$ の更新量 $\delta x$ を以下の式で求めるというNewton法という手法があります.

\begin{aligned}
g_i &= \frac{\partial E}{\partial x_i} \\
H_{ij} &= \frac{\partial^2 E}{\partial x_i\partial x_j} \\
\delta x &= -H^{-1}g
\end{aligned}

ただこれには,Hesse行列 $H$ の計算が重いことや,そもそもHesse行列 $H$ が正則でない場合があるという問題があります.
そこで,Gauss-Newton近似という,以下の近似を用いることで,Newton法の問題を解決したのがGauss-Newton法というわけです( $J$ は $e(x)$ のJacobi行列).

\begin{aligned}
J_{ij} &=\frac{\partial e_i}{\partial x_j} \\
H &\approx J^{\mathrm{T}}J
\end{aligned}

また,連鎖律から以下の式も成り立つことがわかります.

g= J^{\mathrm{T}}e

ということで,要するに,以下の式で求まる更新量 $\delta x$ でパラメータを更新していきます.

\delta x=-\left(J^{\mathrm{T}}J\right)^{-1}J^{\mathrm{T}}e

数学の話終わり.

Rustの線形代数ライブラリ

RustでGauss-Newton法を実装しようと思った場合,(自分でフルスクラッチで書くのでなければ)既存の何かしらの線形代数ライブラリを使うことになるわけですが,以下の2つが有力かなと思いました(ぱっと調べた感じの単なる印象ですが).

その他,気になったところ(使ったことはない).

  • cgmath
    • 今回の用途には向かなそうだが,クォータニオンをサポートしていたり3次元処理とかなら,非常に役に立ちそう.crates.ioを見るとダウンロード数が多い(100万超)ので,実際ゲーム系のライブラリで使われていたりするのだろうか.
  • sprs
    • まだ珍しい疎行列ライブラリ.ndarray互換で,CSC/CRC形式の両方をサポート.
  • arrayfire
    • ArrayFire (C++製)のラッパ.GPUが使えるのは貴重かな?

https://www.reddit.com/r/rust/comments/63wts9/why_are_there_so_many_linear_algebra_crates_which/ に挙がっているのだと,他は軒並み開発終了ですかね…….


  • argmin
    • 線形代数ライブラリではなく,最適化ライブラリなので別枠で.WIPらしいので今後に期待.

実装

ということで,個人的に有力と感じた,ndarrayおよびnalgebraでGauss-Newton法を実装しました.
とはいえ,両方載せても,長くなるだけで大して面白くならない(できない)ので,ndarrayを使った実装だけのせます.
また,あまり重要でないと思ったところを適宜省略したりするので,全体を見たい方は,コードを https://github.com/eduidl/gauss-newton-rs にあげているのでそちらを参照してください.

気持ち抽象化しているのは,何か作ろうと思った物の残骸です.

Problem

後で実装するソルバが欲しい情報(Jacobi行列等)それぞれを返すようにtriatを定義しておきます.

problem.rs
pub trait Problem
where
    Array1<Self::T>: Norm<Output = Self::T>,
{
    type T: Scalar;

    fn params(&self) -> &Array1<Self::T>;
    fn update_params(&mut self, delta_param: &Array1<Self::T>);
    fn error_vector(&self) -> Array1<Self::T>;
    fn squared_error(&self) -> Self::T {
        self.error_vector().norm_l2()
    }
    fn jacobian(&self) -> Array2<Self::T>;
}

Solver

Gauss-Newton法の本体です. step がメイン処理です.終了条件のチェックが行数をかさ増ししていますが,中身は単純で,概ね先ほどの説明どおりです.
一つ少し違いがあって,先ほどの説明で以下のような式を示しましたが,$\left(J^{\mathrm{T}}J\right)^{-1}$ を求めて行列積を計算するというのは,精度上あまりよくないとされています.

\delta x=-\left(J^{\mathrm{T}}J\right)^{-1}J^{\mathrm{T}}e

そのため,実際には以下の式を solvec_inplace を用いて解いています.

\left(J^{\mathrm{T}}J\right)\delta x=-J^{\mathrm{T}}e

A (= $J^{\mathrm{T}}J$ )が正定値対称行列であるため,Cholesky分解を利用して解いているわけです.

solver.rs
pub struct GaussNewton<P, T> {
    pub problem: P,
    converged: bool,
    iters: usize,
    eps: T,
}

#[allow(non_snake_case)]
impl<P: Problem> GaussNewton<P, P::T>
where
    P::T: Float + From<f32>,
    Array1<P::T>: Norm<Output = P::T>,
    Array2<P::T>: Norm<Output = P::T> + SolveC<P::T>,
{
    pub fn new(problem: P) -> Self {
        Self {
            problem,
            converged: false,
            iters: 0,
            eps: 1e-6.into(),
        }
    }

    pub fn step(&mut self) {
        assert!(!self.converged);
        self.iters += 1;

        let prev_squared_error = self.problem.squared_error();

        let J = self.problem.jacobian();
        let A = J.t().dot(&J);
        let mut a = -J.t().dot(&self.problem.error_vector());
        // TODO: error handling
        let _ = A.solvec_inplace(&mut a).unwrap();
        self.problem.update_params(&a);

        // 以降終了条件のチェック
        let squared_error = self.problem.squared_error();
        let delta_squared_error = (squared_error - prev_squared_error).abs();
        if delta_squared_error / squared_error < self.eps {
            self.converged = true;
            return;
        }
        let delta_x_norm = a.norm_l2();
        let x_norm = self.problem.params().norm_l2();
        if delta_x_norm / x_norm < self.eps {
            self.converged = true;
        }
    }
}

問題を解いてみる

https://ja.wikipedia.org/wiki/ガウス・ニュートン法 にある,Michaelis–Menten式の問題を解きます.
Wikipediaのページ見れば,Jacobi行列の式までご丁寧に書いてある(2019/12/23 現在)ので,詳細は割愛.

use ndarray::{arr1, stack, Array1, Array2, Axis};

use gauss_newton_rs::ndarray_ver::{GaussNewton, Problem};

struct SampleProblem {
    x: Array1<f32>,
    s: Array1<f32>,
    v: Array1<f32>,
}

impl SampleProblem {
    fn new() -> Self {
        let x = arr1(&[1.5, 1.5]);
        let s = arr1(&[0.038, 0.194, 0.425, 0.626, 1.253, 2.500, 3.740]);
        let v = arr1(&[0.050, 0.127, 0.094, 0.2122, 0.2729, 0.2665, 0.3317]);
        Self { x, s, v }
    }
}

impl Problem for SampleProblem {
    type T = f32;

    fn params(&self) -> &Array1<Self::T> {
        &self.x
    }

    fn update_params(&mut self, delta_params: &Array1<Self::T>) {
        self.x += delta_params;
    }

    fn error_vector(&self) -> Array1<Self::T> {
        let predicted = &self.s * self.x[0] / (&self.s + self.x[1]);
        &self.v - &predicted
    }

    fn jacobian(&self) -> Array2<Self::T> {
        let c1 = -&self.s / (&self.s + self.x[1]);
        let c2 = &self.s * self.x[0] / (&self.s + self.x[1]).mapv(|a| a.powi(2));
        let jacobian = stack![Axis(1), c1.insert_axis(Axis(1)), c2.insert_axis(Axis(1))];
        jacobian
    }
}

#[test]
fn test_gauss_newton_ndarray() {
    let problem = SampleProblem::new();
    let begin_squared_error = problem.squared_error();

    let mut gauss_newton = GaussNewton::new(problem);
    gauss_newton.solve_until_converged();

    let end_squared_error = gauss_newton.problem.squared_error();

    println!("{} iter(s)", gauss_newton.iters());
    assert!(begin_squared_error > end_squared_error);
    assert!(0.1 > end_squared_error);
    assert!(gauss_newton
        .problem
        .x
        .all_close(&arr1(&[0.362, 0.556]), 1e-3));
}

解けました.

nalgebraについても少し

nalgebraについてまったく触れないのもあれなので,一番大きな違いを感じたところだけ紹介します.
具体的には jacobian なんですが,ndarrayだと以下のように書くだけで勝手に要素ごとの演算になっていました(それでもPythonと違って,**がないので,mapvは使わざるを得ないのですが).

    fn jacobian(&self) -> Array2<Self::T> {
        let c1 = -&self.s / (&self.s + self.x[1]);
        let c2 = &self.s * self.x[0] / (&self.s + self.x[1]).mapv(|a| a.powi(2));
        let jacobian = stack![Axis(1), c1.insert_axis(Axis(1)), c2.insert_axis(Axis(1))];
        jacobian
    }

それがnalgebraだと,以下のような感じになります.これはこれで,わかりやすいっちゃわかりやすいですが.

    fn jacobian(&self) -> na::DMatrix<Self::T> {
        let c1 = self.s.map(|ss| -ss / (self.x[1] + ss));
        let c2 = self.s.map(|ss| self.x[0] * ss / (self.x[1] + ss).powi(2));
        let jacobian = na::DMatrix::<Self::T>::from_columns(&[c1, c2]);
        jacobian
    }

参考

  1. https://twitter.com/tonets/status/1097689173270511616 (ネタとしてはもう賞味期限切れか?)

9
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
9
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?