LoginSignup
2
2

More than 5 years have passed since last update.

Apache Commons Math3のLevenbergMarquardtOptimizerを利用した2次元ガウス分布のフィッティング

Last updated at Posted at 2015-01-07

はじめに

Apache Commons Math3のLM法によるフィッティングができるようになったので、2次元ガウス分布のフィッティングを実装してみる。ImageJなどから得た画像データをフィッティングすることを目的としている。基本的には、2次関数のときの場合と同じだが、ArrayListを使うのをやめたりと少し変更している。

2 dimensional Gaussian Functionの実装

QuadraticFunctionの場合のメンバ変数とコンストラクタは

    // Member variables
    List<Double> x;
    List<Double> y;

    /**
     * Constructor of QuadraticFunction
     * @param x input data
     * @param y target data
     */
    public QuadraticFunction(List<Double> x, List<Double> y) {
        this.x=x;
        this.y=y;
    }

はこんな感じで、入力データはArrayList型として受け取っている。ここで、変数xはx座標のデータだが、QuadraticFunctionProblemでの入力の仕方を見てみると、

        qf.addPoint(1, 34.234064369);
        qf.addPoint(2, 68.2681162306);
        ...

こんな感じで、(x,y)の値を入力しているが、xはぶっちゃけデータに対して1ずつインクリメントしていくだけなので、わざわざ与える必要はないと考えた(空間スケール情報は、フィッティングされたパラメータに対して与えればよいし)。y座標データについては、QuadraticFunctionではあとでcalculateTarget()メソッドの返り値として利用している。このメソッドは

    public double[] calculateTarget() {
        double[] target = new double[y.size()];
        for (int i = 0; i < y.size(); i++) {
            target[i] = y.get(i).doubleValue();
        }
        return target;
    }

となっており、ただdouble[]を返すだけであり、さらにこの後直接yの値(ターゲットデータ)を利用することはない。

なので、以下のように割り切ってみた。

public class TwoDGaussianFunction {

    // Member variables
    int x_width; // width of data
    int data_size; // data size

    /**
     * @param data  input data 
     * @param x_width   input data width 
     */
    public TwoDGaussianFunction(int x_width,int data_size) {

        this.x_width = x_width;
        this.data_size = data_size;

    }

コンストラクと時に必要な引数として渡しているのは、データの横幅とサイズだけである。ではf(x,y)の値はどうやって渡すかというと、メインのクラスの方(呼び出す方)で、あらかじめ1次元配列として与えておいて,

        //entry the data
        double[] inputdata = {
                0  ,12 ,25 ,12 ,0  ,
                12 ,89 ,153,89 ,12 ,
                25 ,153,255,153,25 ,
                12 ,89 ,153,89 ,12 ,
                0  ,12 ,25 ,12 ,0  ,
        };

     ...

        //set target data
        lsb.target(inputdata);

と直接渡すことにした。ここで、lsb

        //prepare construction of LeastSquresProblem by builder
        LeastSquaresBuilder lsb = new LeastSquaresBuilder();

である。

2次元ガウス関数

次にモデル関数の実装を行う。2次元ガウス関数は一般に次のような形をしている。

f(x,y) = A/sqrt(2*pi*s_x^2*s_y^2)*exp(-(x-x_m)^2/(2*s_x^2))*exp(-(y-y_m)^2/(2*s_y^2))+offset

パラメータは、それぞれA; 振幅、x_m, y_m; x,yの平均値, s_x, s_y; x,yの標準偏差, offset; オフセット値である。これをMultvariateVectorFunction()内のメソッドvalue内で定義する必要がある。v0~v1までが、実際に実装する際のパラメータである (double[] vの引数)

public double[] value(double[] v)
        throws IllegalArgumentException {
    double[] values = new double[data_size];
    // pre-calculation
    double v3v3 = v[3]*v[3];
    double v4v4 = v[4]*v[4];
    double sqrt_twopiv3v4 = Math.sqrt(2*Math.PI*v3v3*v4v4); 
    for (int i = 0; i < values.length; ++i) {
        // parameters for x,y positioning
        int xi = i % x_width;   
        int yi = i / x_width;

        values[i] = v[0]/sqrt_twopiv3v4
                    *Math.exp(-(xi-v[1])*(xi-v[1])/(2*v3v3))
                    *Math.exp(-(yi-v[2])*(yi-v[2])/(2*v4v4))
                    +v[5];
                }               
                return values;
            }

double v3v3, v4v4, sqrt_wtopiv3v4は定数なのでつぎのForループの前であらかじめ計算させている。また、ターゲットデータを1次元配列として扱っているので、2次元ガウス関数も実際には1次元関数として扱われる。そのため、x,yの値を渡すために、媒介変数としてxi, yiを導入した。あとは、先ほどしめした関数型を実装するのみである。

2次元ガウス関数のjacobian

次に、MultivariateMatrixFunction()の中で実装するメソッドjacobian()を記述する。

private double[][] jacobian(double[] v) {
    double[][] jacobian = new double[data_size][6];              
    double v3v3 = v[3]*v[3];
    double v4v4 = v[4]*v[4];
    double sqrt_twopiv3v4 = Math.sqrt(2*Math.PI*v3v3*v4v4);
    for (int i = 0; i < jacobian.length; ++i) {
        // parameters for x,y positioning
        int xi = i % x_width;
        int yi = i / x_width;
        double exp_x = Math.exp(-(xi-v[1])*(xi-v[1])/(2*v3v3));
        double exp_y = Math.exp(-(yi-v[2])*(yi-v[2])/(2*v4v4));
        //partial differentiations were calculated by using Maxima
        jacobian[i][0] = exp_x*exp_y/sqrt_twopiv3v4;        //df(x,y)/dv0
        jacobian[i][1] = v[0]*(xi-v[1])/v3v3*jacobian[i][0];                        //df(x,y)/dv1
        jacobian[i][2] = v[0]*(yi-v[2])/v4v4*jacobian[i][0];                        //df(x,y)/dv2
        jacobian[i][3] = jacobian[i][1]*(xi-v[1])/v[3]-v[0]*jacobian[i][0]/v[3];    //df(x,y)/dv3
        jacobian[i][4] = jacobian[i][2]*(yi-v[2])/v[4]-v[0]*jacobian[i][0]/v[4];    //df(x,y)/dv4
        jacobian[i][5] = 1;                                                         //df(x,y)/dv5
    }
    return jacobian;
}  

jacobian[i]jは2次元ガウス関数をパラメータv0~v5で偏微分して得られる導函数である。偏微分の計算はなかなか大変なので、maximaを使って計算を行った(関数f(x,y)をmaxima上で定義して、diff()で偏微分しただけ)。各パラメータで偏微分した形をみると、他のパラメータで偏微分した函数と共通項があるので、無駄な計算をしないように前の結果を利用する形で記述をしている。

フィッティング例

これでほぼ実装は終わりだ。あとはQuadrticProblemと同じように実装すればよい。前に示したテストデータを試しにフィッティングしてみると以下の結果を得た。

v0: 648.008445344128
v1: 2.0
v2: 2.0
v3: 0.9870954221255059
v4: 0.9870954221255059
v5: -7.151657318492871
Iteration number: 9
Evaluation number: 11

v1,v2は明らかに2であったので、フィッティングはうまくいっているようだ!v0の値が255より大きいのは、分母にsqrt(2*pi*v3^2*v4^2)があるためである。

サンプルコード

ここにサンプルコードをアップした。

まとめ

2次元ガウス分布によるフィッティングも、1次元関数とみなすことで比較的容易に実装することができた。つまり、SimpleCurveFitterを使うことも多分できるだろうがiterationなどの情報を得ることができないので、やはり少し手間をかけてでもLevenbergMarquardtOptimizerを使ったほうが良いだろう。

追記 (2015.01.08)

GitにアップしたコードはDefalut packageになっていたので、pta2.track.tdgaussianというパッケージ名を与えたコードも追加してみた。また、クラス名TwoDGaussProblemとしてコンストラクタを与え、外部から利用しやすいようにしてみた。

/**
* @param data   input data
* @param newStart   initial values
* @param data_width data width
* @param optim_param    [0]:maxEvaluation, [1]:maxIteration
*/
public TwoDGaussProblem(double[] data, double[] newStart, int data_width, int[] optim_param) 

メソッドfit2DGauss()を実行すると、Optimum型としてフィッティングしたデータを返す。

TwoDGaussProblem tdgp = new TwoDGaussProblem(inputdata, newStart, 5, new int[] {1000,100});

try{
    //do LevenbergMarquardt optimization and get optimized parameters
    Optimum opt = tdgp.fit2dGauss();
    final double[] optimalValues = opt.getPoint().toArray();
    //output data
    System.out.println("v0: " + optimalValues[0]);
    System.out.println("v1: " + optimalValues[1]);
                 ...
2
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
2