1
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

Nelder-Mead法をC++に移植

Posted at

Nelder-Mead法は、目的関数の勾配を使用せずに最適化を行うアルゴリズム。
「JuliaによるNelder-Meadアルゴリズムの実装」のコードを、JuliaからC++に移植してみた。簡略化して、実装もヘッダーに記述している。
「std::vectorを拡張してみた(その後)」で作成したmy_vector.hを使用して、std::vectorの四則演算を実行している。

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <iostream>
#include <utility>
#include <vector>

#include "my_vector.h"

// Adaptive Nelder-Mead Simplex (ANMS) algorithm.
// The transformation coefficients used in minimize() are the dimension-
// dependent ones of Gao & Han: alpha = 1, beta = 1 + 2/n,
// gamma = 0.75 - 1/(2n), delta = 1 - 1/n.

/// Objective function signature: receives the dimension `n`, a pointer to the
/// `n` coordinates of the point to evaluate, and the opaque pointer `ex` that
/// was passed to nelder_mead::minimize; returns the objective value.
using optimfn = double(int n, const double *x, void *ex);

/// Derivative-free minimizer. After minimize() returns, the public members
/// hold the result of that call.
class nelder_mead {
private:
    // One simplex vertex: a point and its cached objective value.
    struct vertex {
        std::vector<double> x;
        double v;
    };

    // Ordering of vertices by objective value that ranks NaN as worse than any
    // number. Plain `<` is not a strict weak ordering once NaN appears, which
    // would make std::sort undefined behaviour.
    static bool better(const vertex& a, const vertex& b) {
        if (std::isnan(a.v)) return false;
        if (std::isnan(b.v)) return true;
        return a.v < b.v;
    }

    // Element-wise base + (a - b) * t. All inputs must have the same size.
    static std::vector<double> step(const std::vector<double>& base,
                                    const std::vector<double>& a,
                                    const std::vector<double>& b, double t) {
        std::vector<double> out(base.size());
        for (std::size_t j = 0; j < base.size(); ++j) {
            out[j] = base[j] + (a[j] - b[j]) * t;
        }
        return out;
    }

    // Centroid of every vertex except index `h`; pass h = -1 to average all
    // vertices. Requires a non-empty simplex.
    std::vector<double> centroid(int h) const {
        const std::size_t dim = simplex.size() - 1;
        std::vector<double> c(dim, 0.0);
        int m = 0;
        for (int i = 0; i < static_cast<int>(simplex.size()); ++i) {
            if (i == h) continue;
            for (std::size_t j = 0; j < dim; ++j) c[j] += simplex[i].x[j];
            ++m;
        }
        for (double& cj : c) cj /= m;
        return c;
    }

    // Vertices kept sorted by `better` (best first) between iterations.
    std::vector<vertex> simplex;

public:
    int iterations = 0;        // iterations performed by the last minimize()
    int fncount = 0;           // objective evaluations made by the last minimize()
    std::vector<double> xout;  // best point found by the last minimize()
    double fmin = 0;           // objective value at xout

    /// Minimizes `fminfn` starting from `x0`.
    /// @param fminfn objective, called as fminfn(n, x, ex)
    /// @param ex     opaque pointer forwarded unchanged to every fminfn call
    /// @param x0     starting point; its size is the problem dimension n
    /// @param maxit  maximum number of simplex iterations
    /// @param ftol   value tolerance: every vertex value within ftol of the best
    /// @param xtol   domain tolerance: every vertex coordinate within xtol of the
    ///               best vertex's coordinate
    /// Iteration stops when both tolerances hold or maxit is reached. On return,
    /// xout/fmin are the better of the best vertex and the simplex centroid.
    /// An empty x0 just evaluates fminfn once at x0.
    void minimize(optimfn fminfn, void *ex, const std::vector<double>& x0,
        int maxit, double ftol, double xtol) {
        const int n = static_cast<int>(x0.size());

        fncount = 0;
        iterations = 0;
        // Counts every objective evaluation.
        auto fcall = [&](const std::vector<double>& x) {
            fncount++;
            return fminfn(static_cast<int>(x.size()), x.data(), ex);
        };

        // Zero-dimensional problem: the coefficients below would divide by
        // zero and the simplex would have no second-highest vertex.
        if (n == 0) {
            simplex.clear();
            xout = x0;
            fmin = fcall(x0);
            return;
        }

        // Adaptive transformation coefficients.
        const double alpha = 1.0;
        const double beta = 1.0 + 2.0 / n;
        const double gamma = 0.75 - 1.0 / 2.0 / n;
        const double delta = 1.0 - 1.0 / n;

        // Initial simplex: x0, plus one vertex per axis with that coordinate
        // perturbed by 5% (or by 0.00025 when the coordinate is exactly zero).
        const double f0 = fcall(x0);
        simplex.assign(n + 1, vertex{x0, f0});
        for (int i = 1; i <= n; ++i) {
            const double tau = (x0[i - 1] == 0.0 ? 0.00025 : 0.05 * x0[i - 1]);
            simplex[i].x[i - 1] += tau;
            simplex[i].v = fcall(simplex[i].x);
        }
        std::sort(simplex.begin(), simplex.end(), better);

        // Centroid of the n best vertices; updated incrementally below.
        std::vector<double> c = centroid(n);

        // Best (lowest) vertex.
        std::vector<double> xl = simplex[0].x;
        double fl = simplex[0].v;

        bool domconv = false;   // domain convergence
        bool fvalconv = false;  // function-value convergence

        while (iterations < maxit && !(fvalconv && domconv)) {
            // Worst and second-worst vertices.
            const std::vector<double> xh = simplex[n].x;
            const double fh = simplex[n].v;
            const double fs = simplex[n - 1].v;

            // Reflect the worst vertex through the centroid.
            const std::vector<double> xr = step(c, c, xh, alpha);
            const double fr = fcall(xr);

            bool doshrink = false;
            std::vector<double> x;  // accepted replacement for the worst vertex
            double fx = 0;

            if (fr < fl) {
                // Expand further along the reflection direction.
                std::vector<double> xe = step(c, xr, c, beta);
                const double fe = fcall(xe);
                if (fe < fr) {
                    x = std::move(xe);
                    fx = fe;
                } else {
                    x = xr;
                    fx = fr;
                }
            } else if (fr < fs) {
                x = xr;
                fx = fr;
            } else if (fr < fh) {
                // Outside contraction: c + (xr - c) * gamma.
                std::vector<double> xc = step(c, xr, c, gamma);
                const double fc = fcall(xc);
                if (fc <= fr) {
                    x = std::move(xc);
                    fx = fc;
                } else {
                    doshrink = true;
                }
            } else {
                // Inside contraction: c - (xr - c) * gamma == c + (c - xr) * gamma.
                std::vector<double> xc = step(c, c, xr, gamma);
                const double fc = fcall(xc);
                if (fc < fh) {
                    x = std::move(xc);
                    fx = fc;
                } else {
                    doshrink = true;
                }
            }

            if (doshrink) {
                // Shrink every vertex toward the best one (rare in practice).
                for (int i = 1; i <= n; ++i) {
                    simplex[i].x = step(xl, simplex[i].x, xl, delta);
                    simplex[i].v = fcall(simplex[i].x);
                }
                std::sort(simplex.begin(), simplex.end(), better);
                c = centroid(n);
            } else {
                // Replace the worst vertex, then sift it down into sorted position.
                simplex[n].x = x;
                simplex[n].v = fx;
                for (int i = n; i > 0 && better(simplex[i], simplex[i - 1]); --i) {
                    std::swap(simplex[i], simplex[i - 1]);
                }
                // Centroid update: add the new vertex, drop the (new) worst one.
                const std::vector<double>& worst = simplex[n].x;
                for (int j = 0; j < n; ++j) c[j] += (x[j] - worst[j]) / n;
            }

            xl = simplex[0].x;
            fl = simplex[0].v;

            // Written as !(<=) so that a NaN value does not count as converged.
            fvalconv = true;
            for (int i = 1; i <= n && fvalconv; ++i) {
                if (!(std::abs(simplex[i].v - fl) <= ftol)) fvalconv = false;
            }
            domconv = true;
            for (int i = 1; i <= n && domconv; ++i) {
                for (int j = 0; j < n; ++j) {
                    if (std::abs(simplex[i].x[j] - xl[j]) > xtol) {
                        domconv = false;
                        break;
                    }
                }
            }

            iterations++;
        }

        // Report the better of the best vertex and the whole-simplex centroid.
        c = centroid(-1);
        const double fcent = fcall(c);
        if (fcent < fl) {
            xout = std::move(c);
            fmin = fcent;
        } else {
            xout = xl;
            fmin = fl;
        }
    }
};

使用例

// Uncoupled n-dimensional Rosenbrock function:
//   f(x) = sum over i = 0, 2, 4, ... (i + 1 < n) of
//          (a - x[i])^2 + b * (x[i+1] - x[i]^2)^2
//        + (a - x[n-1])^2 when n is odd (the unpaired last coordinate).
// Every term is non-negative, so the minimum f = 0 is reached at
// x[i] = a, x[i+1] = a^2 for each pair (and x[n-1] = a when n is odd).
// `ex` must point to two doubles {a, b}; it is read-only and must not be null.
double rosenbrock(int n, const double *x, void *ex) {
    const double *params = static_cast<const double *>(ex);
    const double a = params[0];
    const double b = params[1];
    double ret = 0;
    for (int i = 0; i < n - 1; i += 2) {
        const double s = a - x[i];
        const double t = x[i + 1] - x[i] * x[i];
        ret += s * s + b * t * t;
    }
    if (n % 2 == 1) {
        const double s = a - x[n - 1];
        ret += s * s;
    }
    return ret;
}

// Sphere test function: the sum of squares of the n coordinates of x.
// Its minimum 0 is at the origin. `ex` is not used.
double quadratic(int n, const double *x, void *ex) {
    (void)ex;
    double sum = 0;
    int k = 0;
    while (k < n) {
        const double xk = x[k];
        sum += xk * xk;
        ++k;
    }
    return sum;
}

void test_anms() {
    std::cout << "test_anms" << std::endl;
    int maxit = 1000000;
    double ftol = 1.0e-8;
    double xtol = 1.0e-8;
    nelder_mead obj;

    for (int n = 2; n <= 10; n++) {
        double a = 1.0;
        double b = 100.0;
        std::vector<double> params = { a, b };
        std::vector<double> x0(n, 0);
        std::vector<double> x1(n, a);
        for (int i = 0; i < n - 1; i += 2) {
            x1[i + 1] = a * a;
        }
        void *ex = params.data();
        obj.minimize(rosenbrock, ex, x0, maxit, ftol, xtol);
        std::cout << "rosenbrock,";
        std::cout << n << ",";
        std::cout << std::boolalpha << (norm(obj.xout - x1) < xtol) << ",";
        std::cout << std::boolalpha << (std::abs(obj.fmin) < ftol) << ",";
        std::cout << obj.fncount << std::endl;
    }

    for (int n = 2; n <= 10; n++) {
        std::vector<double> x0(n, 1.0);
        std::vector<double> x1(n, 0.0);
        void *ex = NULL;
        obj.minimize(quadratic, ex, x0, maxit, ftol, xtol);
        std::cout << "quadratic,";
        std::cout << n << ",";
        std::cout << std::boolalpha << (norm(obj.xout - x1) < xtol) << ",";
        std::cout << std::boolalpha << (std::abs(obj.fmin) < ftol) << ",";
        std::cout << obj.fncount << std::endl;
    }
}

実行結果

test_anms
rosenbrock,2,true,true,207
rosenbrock,3,true,true,472
rosenbrock,4,true,true,790
rosenbrock,5,true,true,1155
rosenbrock,6,true,true,1828
rosenbrock,7,true,true,2012
rosenbrock,8,true,true,2795
rosenbrock,9,true,true,3293
rosenbrock,10,true,true,5371
quadratic,2,true,true,128
quadratic,3,true,true,285
quadratic,4,true,true,424
quadratic,5,true,true,593
quadratic,6,true,true,743
quadratic,7,true,true,957
quadratic,8,true,true,1133
quadratic,9,true,true,1321
quadratic,10,true,true,1494

参考文献

1
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?