Skip to content

Commit b9f01b1

Browse files
committed
Implemented optimx wrapper
1 parent d55394b commit b9f01b1

File tree

1 file changed

+132
-0
lines changed

1 file changed

+132
-0
lines changed

R/optimize.R

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
#' Optimize an Objective Function using Various Optimization Methods using optimx
2+
#'
3+
#' This function extends the existing optimization capabilities by integrating the `optimx` package,
4+
#' allowing the use of various optimization algorithms such as "L-BFGS-B", "BFGS", "Nelder-Mead", etc.
5+
#'
6+
#' @param objfn A function that computes and returns a list with components `value`, `gradient`, and `hessian`.
7+
#' This represents the objective function to be minimized.
8+
#' @param parinit A numeric vector of initial parameter values.
9+
#' @param method A character string specifying the optimization method. Defaults to "L-BFGS-B".
10+
#' Available methods include those supported by the `optimx` package, such as "BFGS",
11+
#' "Nelder-Mead", "L-BFGS-B", etc.
12+
#' @param lower A numeric vector of lower bounds for the parameters (used only by methods that support
13+
#' box constraints, e.g., "L-BFGS-B"). Defaults to `-Inf`.
14+
#' @param upper A numeric vector of upper bounds for the parameters (used only by methods that support
15+
#' box constraints, e.g., "L-BFGS-B"). Defaults to `Inf`.
16+
#' @param control A list of control parameters to pass to the optimization algorithm.
17+
#' @param ... Additional arguments to pass to the objective function.
18+
#'
19+
#' @return A list containing:
20+
#' - `value`: The value of the objective function at the optimum.
21+
#' - `gradient`: The gradient at the optimum.
22+
#' - `hessian`: The Hessian at the optimum.
23+
#' - `argument`: The optimized parameters.
24+
#' - `converged`: Logical indicating if the optimizer converged.
25+
#' - `iterations`: The number of function evaluations.
26+
#' @import optimx
27+
#' @export
28+
optimize <- function(objfn, parinit, method = "L-BFGS-B", lower = -Inf, upper = Inf, control = list(), ...) {
29+
30+
# Sanitize the initial parameters.
31+
sanePars <- sanitizePars(parinit, list(...)$fixed)
32+
parinit <- sanePars$pars
33+
34+
# Ensure lower/upper bounds are vectors with proper names.
35+
if (length(lower) == 1) {
36+
lower <- rep(lower, length(parinit))
37+
names(lower) <- names(parinit)
38+
} else if (is.null(names(lower))) {
39+
names(lower) <- names(parinit)
40+
}
41+
42+
if (length(upper) == 1) {
43+
upper <- rep(upper, length(parinit))
44+
names(upper) <- names(parinit)
45+
} else if (is.null(names(upper))) {
46+
names(upper) <- names(parinit)
47+
}
48+
49+
# Initialize cache to store previously evaluated parameter results.
50+
cache <- new.env(hash = TRUE, parent = emptyenv())
51+
52+
# Helper function to generate a unique key for each parameter set.
53+
generate_key <- function(par) paste0(format(par, digits = 8), collapse = ",")
54+
55+
# Define the function wrappers required by optimx with error handling and caching.
56+
fn <- function(par, ...) {
57+
names(par) <- names(parinit)
58+
key <- generate_key(par)
59+
60+
if (exists(key, envir = cache)) {
61+
result <- get(key, envir = cache)
62+
} else {
63+
result <- try(objfn(par, ...), silent = TRUE)
64+
assign(key, result, envir = cache)
65+
}
66+
67+
if (inherits(result, "try-error") || !is.finite(result$value)) {
68+
return(1e10) # Large penalty for solver failures
69+
}
70+
result$value
71+
}
72+
73+
gr <- function(par, ...) {
74+
names(par) <- names(parinit)
75+
key <- generate_key(par)
76+
77+
if (exists(key, envir = cache)) {
78+
result <- get(key, envir = cache)
79+
} else {
80+
result <- try(objfn(par, ...), silent = TRUE)
81+
assign(key, result, envir = cache)
82+
}
83+
84+
if (inherits(result, "try-error") || !all(is.finite(result$gradient))) {
85+
return(rep(0, length(par))) # Return zero gradient on failure
86+
}
87+
unname(result$gradient)
88+
}
89+
90+
hess <- function(par, ...) {
91+
names(par) <- names(parinit)
92+
key <- generate_key(par)
93+
94+
if (exists(key, envir = cache)) {
95+
result <- get(key, envir = cache)
96+
} else {
97+
result <- try(objfn(par, ...), silent = TRUE)
98+
assign(key, result, envir = cache)
99+
}
100+
101+
if (inherits(result, "try-error") || !all(is.finite(result$hessian))) {
102+
return(diag(length(par))) # Return identity matrix if Hessian fails
103+
}
104+
unname(result$hessian)
105+
}
106+
107+
# Perform the optimization.
108+
optim_result <- optimx::optimr(
109+
par = as.numeric(parinit),
110+
fn = fn,
111+
gr = gr,
112+
hess = hess,
113+
method = method,
114+
lower = unname(lower),
115+
upper = unname(upper),
116+
control = control,
117+
...
118+
)
119+
120+
# Extract the optimized parameters.
121+
final_par <- structure(optim_result$par, names = names(parinit))
122+
123+
# Evaluate the objective function at the optimum.
124+
final_result <- objfn(final_par, ...)
125+
126+
# Attach optimization metadata.
127+
final_result$argument <- final_par
128+
final_result$converged <- !as.logical(optim_result$convergence)
129+
final_result$iterations <- optim_result$counts["function"]
130+
131+
return(final_result)
132+
}

0 commit comments

Comments
 (0)