# 面白そうなアルゴリズムを見つけた。目的関数の勾配を使用しないで最適化するアルゴリズム。
# JuliaによるNelder-Meadアルゴリズムの実装
# JuliaからRに移植してみた。
# References:
# * Fuchang Gao and Lixing Han (2010), Springer US. "Implementing the Nelder-Mead simplex algorithm with adaptive parameters" (doi:10.1007/s10589-010-9329-3)
# * Saša Singer and John Nelder (2009), Scholarpedia, 4(7):2928. "Nelder-Mead algorithm" (doi:10.4249/scholarpedia.2928)
# Adaptive Nelder-Mead Simplex (ANMS) algorithm
#' Minimize a multivariate function with the Adaptive Nelder-Mead Simplex method
#'
#' Derivative-free minimizer: only values of `f` are used, never its gradient.
#' The reflection/expansion/contraction/shrink coefficients depend on the
#' dimension `n = length(x0)` (adaptive parameters, see the Gao & Han reference
#' cited in the file header).
#'
#' @param f Objective function. It is called as `f(x)` with a numeric vector of
#'   length `length(x0)`; its result is compared with `<`, so it is expected to
#'   return a single numeric value.
#' @param x0 Numeric start point with at least two elements.
#' @param iterations Maximum number of simplex iterations.
#' @param ftol Function-value tolerance: satisfied when every vertex value is
#'   within `ftol` of the best vertex value.
#' @param xtol Domain tolerance: satisfied when every coordinate of every
#'   vertex is within `xtol` of the best vertex. Iteration stops once both
#'   tolerances are satisfied (or `iterations` is reached).
#' @return If `length(x0) < 2`, the character string
#'   `"multivariate function is needed"` (no error is signalled). Otherwise a
#'   list with `par` (best point found), `value` (objective value at `par`)
#'   and `counts` (number of calls made to `f`, including the final centroid
#'   evaluation).
nelder_mead <- function(f, x0, iterations = 1000000, ftol = 1.0e-8,
                        xtol = 1.0e-8) {
  # dimension of the search space
  n <- length(x0)
  if (n < 2) {
    # Kept as a returned string (not stop()) so existing callers are unaffected.
    return("multivariate function is needed")
  }

  # dimension-dependent coefficients of the simplex transformations
  alpha <- 1.0               # reflection
  beta <- 1.0 + 2 / n        # expansion
  gamma <- 0.75 - 1 / 2 / n  # contraction
  delta <- 1.0 - 1 / n       # shrink

  # Every objective evaluation goes through fcall() so that the call counter
  # reported in `counts` stays accurate.
  fcalls <- 0
  fcall <- function(x) {
    fcalls <<- fcalls + 1
    f(x)
  }

  # centroid of all vertices except row h
  centroid_without <- function(vertices, h) {
    colMeans(vertices[-h, , drop = FALSE])
  }

  # Initial simplex: x0 plus one vertex per coordinate, each perturbed by 5%
  # of that coordinate (or by 0.00025 when the coordinate is exactly zero).
  simplex <- matrix(x0, nrow = n + 1, ncol = n, byrow = TRUE)
  fvalues <- rep(fcall(x0), n + 1)
  for (i in seq_len(n)) {
    tau <- if (x0[i] == 0.0) 0.00025 else 0.05 * x0[i]
    x <- x0
    x[i] <- x[i] + tau
    simplex[i + 1, ] <- x
    fvalues[i + 1] <- fcall(x)
  }
  # ord[k] is the row index of the k-th best vertex (ascending f)
  ord <- order(fvalues)

  # stopping criteria
  iter <- 0          # number of iterations
  domconv <- FALSE   # domain convergence
  fvalconv <- FALSE  # function-value convergence

  # cached centroid of all vertices except the worst one
  cen <- centroid_without(simplex, ord[n + 1])

  while ((iter < iterations) && !(fvalconv && domconv)) {
    # highest (worst), second highest, and lowest (best) vertices
    h <- ord[n + 1]
    s <- ord[n]
    l <- ord[1]
    xh <- simplex[h, ]
    fh <- fvalues[h]
    fs <- fvalues[s]
    xl <- simplex[l, ]
    fl <- fvalues[l]

    # reflect the worst vertex through the centroid
    xr <- cen + alpha * (cen - xh)
    fr <- fcall(xr)

    doshrink <- FALSE
    accept <- NULL
    if (fr < fl) {
      # expand: the reflected point is the new best, try going further
      xe <- cen + beta * (xr - cen)
      fe <- fcall(xe)
      accept <- if (fe < fr) list(xe, fe) else list(xr, fr)
    } else if (fr < fs) {
      # plain reflection
      accept <- list(xr, fr)
    } else {
      # contract (fs <= fr)
      if (fr < fh) {
        # outside contraction
        xc <- cen + gamma * (xr - cen)
        fc <- fcall(xc)
        if (fc <= fr) {
          accept <- list(xc, fc)
        } else {
          doshrink <- TRUE
        }
      } else {
        # inside contraction
        xc <- cen - gamma * (xr - cen)
        fc <- fcall(xc)
        if (fc < fh) {
          accept <- list(xc, fc)
        } else {
          doshrink <- TRUE
        }
      }
      if (doshrink) {
        # shrink every vertex except the best one towards the best one
        for (o in ord[-1]) {
          xi <- xl + delta * (simplex[o, ] - xl)
          simplex[o, ] <- xi
          fvalues[o] <- fcall(xi)
        }
      }
    }

    # update simplex, ordering and centroid cache
    if (doshrink) {
      ord <- order(fvalues)
      cen <- centroid_without(simplex, ord[n + 1])
    } else {
      x <- accept[[1]]
      simplex[h, ] <- x
      fvalues[h] <- accept[[2]]
      # Only the replaced vertex can be out of place: move it towards the
      # front of `ord` until the ordering is ascending again.
      for (i in (n + 1):2) {
        if (fvalues[ord[i - 1]] > fvalues[ord[i]]) {
          tmp <- ord[i - 1]
          ord[i - 1] <- ord[i]
          ord[i] <- tmp
        } else {
          break
        }
      }
      # Incremental centroid update: the accepted point joins the averaged
      # set and the (possibly new) worst vertex leaves it.
      cen <- cen + (x - simplex[ord[n + 1], ]) / n
    }

    # Convergence is tested on ALL vertices relative to the best one.
    # (Looping over rows 2..n+1 would skip row 1 whenever the best vertex is
    # stored in another row, allowing premature termination.)
    best <- ord[1]
    fvalconv <- all(abs(fvalues - fvalues[best]) <= ftol)
    domconv <- all(abs(sweep(simplex, 2, simplex[best, ])) <= xtol)

    iter <- iter + 1
  }

  # return the lowest vertex, or the centroid of the whole simplex if that
  # happens to be better
  center <- colMeans(simplex)
  fcent <- fcall(center)
  if (fcent < fvalues[ord[1]]) {
    return(list(par = center, value = fcent, counts = fcalls))
  }
  list(par = simplex[ord[1], ], value = fvalues[ord[1]], counts = fcalls)
}
# Demo 1: Rosenbrock function, started from the origin.
# The result is compared against c(a, a) (the minimizer for a = 1) and the
# minimum value 0; each comparison prints TRUE or FALSE.
tol <- 1e-8
a <- 1.0
b <- 100.0
f <- function(x) {
  (a - x[1])^2 + b * (x[2] - x[1]^2)^2
}
x0 <- c(0.0, 0.0)
# outer parentheses print the full result list
(ret <- nelder_mead(f, x0))
x <- ret$par
fmin <- ret$value
sqrt(sum((x - rep(a, 2))^2)) < tol
abs(fmin) < tol

# Demo 2: sum of squares in 2 to 10 dimensions, started from a vector of ones.
# Produces one row per dimension: n, whether the point and the value are
# within tol of zero, and how many objective evaluations were used.
tol <- 1e-8
f <- function(x) {
  sum(x^2)
}
t(sapply(2:10, function(ndim) {
  fit <- nelder_mead(f, rep(1, ndim))
  pt <- fit$par
  list(
    n = ndim,
    x = sqrt(sum(pt * pt)) < tol,
    f = abs(fit$value) < tol,
    fcalls = fit$counts
  )
}))