# References:
# * Fuchang Gao and Lixing Han (2010), Springer US. "Implementing the Nelder-Mead simplex algorithm with adaptive parameters" (doi:10.1007/s10589-010-9329-3)
# * Saša Singer and John Nelder (2009), Scholarpedia, 4(7):2928. "Nelder-Mead algorithm" (doi:10.4249/scholarpedia.2928)

#' Adaptive Nelder-Mead Simplex (ANMS) minimizer
#'
#' Minimizes `f` with the Nelder-Mead simplex method, using transformation
#' parameters that depend on the problem dimension (see the Gao & Han
#' reference at the top of this file).
#'
#' @param f Objective function; called with a numeric vector of
#'   `length(x0)` and expected to return a single number.
#' @param x0 Numeric starting point with at least two components.
#' @param iterations Maximum number of simplex iterations.
#' @param ftol Function-value tolerance: every vertex value must be within
#'   `ftol` of the best value.
#' @param xtol Domain tolerance: every vertex coordinate must be within
#'   `xtol` of the best vertex.
#' @return A list with `par` (best point found), `value` (objective at
#'   `par`) and `counts` (number of evaluations of `f`). When
#'   `length(x0) < 2` the string "multivariate function is needed" is
#'   returned instead.
nelder_mead <- function(f, x0, iterations = 1000000, ftol = 1.0e-8, xtol = 1.0e-8) {
    # dimension of the search space
    n <- length(x0)
    if (n < 2) {
        # NOTE: reported as a plain return value (not an error) so existing
        # callers that inspect the result keep working.
        return("multivariate function is needed")
    }

    # dimension-dependent transformation parameters
    alpha <- 1.0               # reflection
    beta <- 1.0 + 2 / n        # expansion
    gamma <- 0.75 - 1 / 2 / n  # contraction
    delta <- 1.0 - 1 / n       # shrink

    # Every objective evaluation goes through fcall() so that the number of
    # function calls is counted automatically. `<<-` updates the counter
    # that lives in this function's own environment (not a global).
    fcalls <- 0
    fcall <- function(x) {
        fcalls <<- fcalls + 1
        f(x)
    }

    # centroid of all vertices except the h-th one
    centroid <- function(simplex, h) {
        colMeans(simplex[-h, , drop = FALSE])
    }

    # initial simplex: x0 plus one vertex per coordinate, each obtained by
    # perturbing a single component of x0
    simplex <- matrix(x0, nrow = n + 1, ncol = n, byrow = TRUE)
    fvalues <- rep(fcall(x0), n + 1)
    for (i in seq_len(n)) {
        tau <- if (x0[i] == 0.0) 0.00025 else 0.05 * x0[i]
        x <- x0
        x[i] <- x[i] + tau
        simplex[i + 1, ] <- x
        fvalues[i + 1] <- fcall(x)
    }
    # ord[k] is the row index of the k-th best vertex
    ord <- order(fvalues)

    # stopping criteria
    iter <- 0          # number of iterations
    domconv <- FALSE   # domain convergence
    fvalconv <- FALSE  # function-value convergence

    # cached centroid of all vertices except the worst one
    cen <- centroid(simplex, ord[n + 1])

    while (iter < iterations && !(fvalconv && domconv)) {
        # highest, second highest, and lowest indices, respectively
        h <- ord[n + 1]
        s <- ord[n]
        l <- ord[1]

        xh <- simplex[h, ]
        fh <- fvalues[h]
        fs <- fvalues[s]
        xl <- simplex[l, ]
        fl <- fvalues[l]

        # reflect the worst vertex through the centroid
        xr <- cen + alpha * (cen - xh)
        fr <- fcall(xr)
        doshrink <- FALSE

        if (fr < fl) {
            # expand: the reflected point is the new best, try going further
            xe <- cen + beta * (xr - cen)
            fe <- fcall(xe)
            if (fe < fr) {
                accept <- list(xe, fe)
            } else {
                accept <- list(xr, fr)
            }
        } else if (fr < fs) {
            # reflect
            accept <- list(xr, fr)
        } else {
            # contract (fs <= fr)
            if (fr < fh) {
                # outside contraction
                xc <- cen + gamma * (xr - cen)
                fc <- fcall(xc)
                if (fc <= fr) {
                    accept <- list(xc, fc)
                } else {
                    doshrink <- TRUE
                }
            } else {
                # inside contraction
                xc <- cen - gamma * (xr - cen)
                fc <- fcall(xc)
                if (fc < fh) {
                    accept <- list(xc, fc)
                } else {
                    doshrink <- TRUE
                }
            }

            # shrinkage almost never happens in practice
            if (doshrink) {
                # pull every vertex except the best one towards the best one
                for (i in 2:(n + 1)) {
                    o <- ord[i]
                    xi <- xl + delta * (simplex[o, ] - xl)
                    simplex[o, ] <- xi
                    fvalues[o] <- fcall(xi)
                }
            }
        }

        # update simplex, function values and centroid cache
        if (doshrink) {
            ord <- order(fvalues)
            cen <- centroid(simplex, ord[n + 1])
        } else {
            x <- accept[[1]]
            fvalue <- accept[[2]]

            # the accepted point replaces the worst vertex; bubble its index
            # down into its ordered position
            simplex[h, ] <- x
            fvalues[h] <- fvalue
            for (i in (n + 1):2) {
                if (fvalues[ord[i - 1]] > fvalues[ord[i]]) {
                    tmp <- ord[i - 1]
                    ord[i - 1] <- ord[i]
                    ord[i] <- tmp
                } else {
                    # the remaining entries are already ordered
                    break
                }
            }

            # add the new vertex, and subtract the (new) highest vertex
            h <- ord[n + 1]
            xh <- simplex[h, ]
            cen <- cen + (x - xh) / n
        }

        l <- ord[1]
        xl <- simplex[l, ]
        fl <- fvalues[l]

        # Check convergence against the best vertex over ALL vertices. The
        # rows of `simplex` are not kept sorted, so skipping row 1 would
        # leave one vertex unchecked whenever row 1 is not the best one.
        fvalconv <- max(abs(fvalues - fl)) <= ftol
        domconv <- max(abs(sweep(simplex, 2, xl))) <= xtol

        iter <- iter + 1
    }

    # return the lowest vertex (or the centroid of the simplex, if it is
    # better) and the function value
    cen <- colMeans(simplex)
    fcent <- fcall(cen)
    if (fcent < fvalues[ord[1]]) {
        return(list(par = cen, value = fcent, counts = fcalls))
    }
    list(par = simplex[ord[1], ], value = fvalues[ord[1]], counts = fcalls)
}
# Rosenbrock function: minimize from the origin and print whether the
# result is within `tol` of the point (a, a) and of the value 0.
tol <- 1e-8
a <- 1.0
b <- 100.0
f <- function(x) {
    (a - x[1])^2 + b * (x[2] - x[1]^2)^2
}
x0 <- c(0.0, 0.0)
(ret <- nelder_mead(f, x0))
x <- ret[["par"]]
fmin <- ret[["value"]]
sqrt(sum((x - c(a, a))^2)) < tol
abs(fmin - 0.0) < tol

# Quadratic function: run the same check in dimensions 2 to 10, starting
# from the all-ones vector, and print one row per dimension
# (n, domain check, value check, number of function calls).
tol <- 1e-8
f <- function(x) {
    sum(x^2)
}
t(sapply(2:10, function(n) {
    x0 <- rep(1, n)
    ret <- nelder_mead(f, x0)
    x <- ret[["par"]]
    fmin <- ret[["value"]]
    list(
        n = n,
        x = (sqrt(sum(x * x)) < tol),
        f = (abs(fmin) < tol),
        fcalls = ret[["counts"]]
    )
}))
