LoginSignup
1
1

More than 5 years have passed since last update.

Nelder-Mead法をRに移植

Last updated at Posted at 2017-11-11

面白そうなアルゴリズムを見つけた。目的関数の勾配を使用しないで最適化するアルゴリズム。
JuliaによるNelder-Meadアルゴリズムの実装

JuliaからRに移植してみた。

# References:
# * Fuchang Gao and Lixing Han (2010), Springer US. "Implementing the Nelder-Mead simplex algorithm with adaptive parameters" (doi:10.1007/s10589-010-9329-3)
# * Saša Singer and John Nelder (2009), Scholarpedia, 4(7):2928. "Nelder-Mead algorithm" (doi:10.4249/scholarpedia.2928)

# Adaptive Nelder-Mead Simplex (ANMS) algorithm
#
# Derivative-free minimisation of a multivariate function. The reflection,
# expansion, contraction and shrinkage coefficients depend on the problem
# dimension (Gao & Han's adaptive parameters).
#
# Arguments:
#   f          objective function; it is called with a numeric vector of
#              length(x0) and must return a single comparable number.
#   x0         numeric starting point with at least two elements.
#   iterations maximum number of simplex iterations.
#   ftol       function-value tolerance: converged when every vertex value is
#              within ftol of the best value.
#   xtol       domain tolerance: converged when every coordinate of every
#              vertex is within xtol of the best vertex.
#
# Returns:
#   list(par, value, counts) where `par` is the best point found (the lowest
#   vertex, or the centroid of the final simplex if that is better), `value`
#   is f(par) and `counts` is the number of evaluations of `f`.
#   When length(x0) < 2 the character string "multivariate function is needed"
#   is returned instead (kept as-is for backward compatibility).
nelder_mead <- function(f, x0, iterations = 1000000, ftol = 1.0e-8,
                        xtol = 1.0e-8) {
    # dimension
    n <- length(x0)
    if (n < 2) {
        return("multivariate function is needed")
    }

    # parameters of transformations
    alpha <- 1.0               # reflection
    beta <- 1.0 + 2 / n        # expansion
    gamma <- 0.75 - 1 / 2 / n  # contraction
    delta <- 1.0 - 1 / n       # shrinkage

    # Every evaluation of the objective goes through fcall() so that the
    # counter below is incremented automatically.
    fcalls <- 0
    fcall <- function(x) {
        fcalls <<- fcalls + 1
        f(x)
    }

    # centroid of all vertices except the h-th one
    centroid <- function(simplex, h) {
        colMeans(simplex[-h, , drop = FALSE])
    }

    # initialize the simplex (one vertex per row) and the function values:
    # row 1 is x0, row i + 1 is x0 perturbed along coordinate i
    simplex <- matrix(x0, nrow = n + 1, ncol = n, byrow = TRUE)
    fvalues <- rep(fcall(x0), n + 1)
    for (i in seq_len(n)) {
        tau <- if (x0[i] == 0.0) 0.00025 else 0.05 * x0[i]
        x <- x0
        x[i] <- x[i] + tau
        simplex[i + 1, ] <- x
        fvalues[i + 1] <- fcall(x)
    }
    # ord[k] is the row index of the k-th best vertex
    ord <- order(fvalues)

    # stopping criteria
    iter <- 0          # number of iterations
    domconv <- FALSE   # domain convergence
    fvalconv <- FALSE  # function-value convergence

    # centroid cache (named `cen` so that base::c is not shadowed)
    cen <- centroid(simplex, ord[n + 1])

    while ((iter < iterations) && !(fvalconv && domconv)) {
        # highest, second highest, and lowest indices, respectively
        h <- ord[n + 1]
        s <- ord[n]
        l <- ord[1]

        xh <- simplex[h, ]
        fh <- fvalues[h]
        fs <- fvalues[s]
        xl <- simplex[l, ]
        fl <- fvalues[l]

        # reflect the worst vertex through the centroid
        xr <- cen + alpha * (cen - xh)
        fr <- fcall(xr)
        doshrink <- FALSE
        x_new <- NULL
        f_new <- NULL

        if (fr < fl) {
            # expand: the reflected point is the new best, try going further
            xe <- cen + beta * (xr - cen)
            fe <- fcall(xe)
            if (fe < fr) {
                x_new <- xe
                f_new <- fe
            } else {
                x_new <- xr
                f_new <- fr
            }
        } else if (fr < fs) {
            # reflect
            x_new <- xr
            f_new <- fr
        } else {
            # contract (here fs <= fr)
            if (fr < fh) {
                # outside contraction
                xc <- cen + gamma * (xr - cen)
                fc <- fcall(xc)
                if (fc <= fr) {
                    x_new <- xc
                    f_new <- fc
                } else {
                    doshrink <- TRUE
                }
            } else {
                # inside contraction
                xc <- cen - gamma * (xr - cen)
                fc <- fcall(xc)
                if (fc < fh) {
                    x_new <- xc
                    f_new <- fc
                } else {
                    doshrink <- TRUE
                }
            }

            # shrinkage almost never happens in practice
            if (doshrink) {
                # pull every vertex except the best one towards the best one
                for (o in ord[-1]) {
                    xi <- xl + delta * (simplex[o, ] - xl)
                    simplex[o, ] <- xi
                    fvalues[o] <- fcall(xi)
                }
            }
        }

        # update simplex, function values and centroid cache
        if (doshrink) {
            ord <- order(fvalues)
            cen <- centroid(simplex, ord[n + 1])
        } else {
            # overwrite the worst vertex, then bubble its index up to the
            # position that keeps `ord` sorted by function value
            simplex[h, ] <- x_new
            fvalues[h] <- f_new
            for (i in (n + 1):2) {
                if (fvalues[ord[i - 1]] > fvalues[ord[i]]) {
                    tmp <- ord[i - 1]
                    ord[i - 1] <- ord[i]
                    ord[i] <- tmp
                } else {
                    break
                }
            }

            # incremental centroid update: add the new vertex and subtract
            # whichever vertex is now the highest
            h <- ord[n + 1]
            cen <- cen + (x_new - simplex[h, ]) / n
        }

        l <- ord[1]
        xl <- simplex[l, ]
        fl <- fvalues[l]

        # check convergence against the best vertex. ALL rows are compared:
        # the best vertex is not necessarily stored in row 1, so skipping
        # row 1 could leave a vertex unchecked. isTRUE() treats an undefined
        # comparison (NA/NaN) as "not converged".
        fvalconv <- isTRUE(all(abs(fvalues - fl) <= ftol))
        domconv <- isTRUE(all(abs(sweep(simplex, 2, xl)) <= xtol))

        iter <- iter + 1
    }

    # return the lowest vertex (or the centroid of the simplex, if it is
    # better) together with its function value
    best <- ord[1]
    xcent <- colMeans(simplex)
    fcent <- fcall(xcent)
    if (fcent < fvalues[best]) {
        return(list(par = xcent, value = fcent, counts = fcalls))
    }
    list(par = simplex[best, ], value = fvalues[best], counts = fcalls)
}
# Rosenbrock function check: minimise from the origin and compare the result
# with the known minimiser (a, a) = (1, 1), where the function value is 0.
tol <- 1e-8
a <- 1.0
b <- 100.0
f <- function(x) {
    (a - x[1])^2 + b * (x[2] - x[1]^2)^2
}
x0 <- c(0.0, 0.0)
ret <- nelder_mead(f, x0)
ret
x <- ret[["par"]]
fmin <- ret[["value"]]
# distance of the solution from (a, a) is below the tolerance?
sqrt(sum((x - c(a, a))^2)) < tol
# minimum value is below the tolerance?
abs(fmin) < tol

# Quadratic bowl check: for dimensions 2..10, minimise sum(x^2) from
# (1, ..., 1) and tabulate whether the solution and the minimum value are
# within `tol` of zero, plus the number of objective evaluations.
tol <- 1e-8
f <- function(x) {
    sum(x^2)
}
# sapply() is kept on purpose: it yields the list-matrix that t() turns into
# one row per dimension.
t(sapply(2:10, function(n) {
    res <- nelder_mead(f, rep(1, n))
    par_found <- res[["par"]]
    list(
        n = n,
        x = (sqrt(sum(par_found * par_found)) < tol),
        f = (abs(res[["value"]]) < tol),
        fcalls = res[["counts"]]
    )
}))
1
1
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1