Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 27 additions & 10 deletions src/variable_rate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -192,23 +192,40 @@ function extend_problem(prob::DiffEqBase.AbstractDAEProblem, jumps; rng = DEFAUL
remake(prob; f, u0)
end

struct VR_FRMEventCallback{F, RNG}
idx::Int
affect!::F
rng::RNG
end

# condition: (u, t, integrator)
@inline (c::VR_FRMEventCallback)(u, t, integrator) = u.jump_u[c.idx]

# affect: (integrator)
function (c::VR_FRMEventCallback)(integrator)
c.affect!(integrator)
integrator.u.jump_u[c.idx] = -randexp(c.rng, typeof(integrator.t))
nothing
end

# initialize: (cb, u, t, integrator)
function (c::VR_FRMEventCallback)(cb, u, t, integrator)
integrator.u.jump_u[c.idx] = -randexp(c.rng, typeof(integrator.t))
integrator.uprev.jump_u[c.idx] = integrator.u.jump_u[c.idx]
u_modified!(integrator, false)
nothing
end

function wrap_jump_in_callback(idx, jump; rng = DEFAULT_RNG)
condition = function (u, t, integrator)
u.jump_u[idx]
end
affect! = function (integrator)
jump.affect!(integrator)
integrator.u.jump_u[idx] = -randexp(rng, typeof(integrator.t))
nothing
end
new_cb = ContinuousCallback(condition, affect!;
cb_functor = VR_FRMEventCallback(idx, jump.affect!, rng)
ContinuousCallback(cb_functor, cb_functor;
initialize = cb_functor,
idxs = jump.idxs,
rootfind = jump.rootfind,
interp_points = jump.interp_points,
save_positions = jump.save_positions,
abstol = jump.abstol,
reltol = jump.reltol)
return new_cb
end

function build_variable_callback(cb, idx, jump, jumps...; rng = DEFAULT_RNG)
Expand Down
2 changes: 1 addition & 1 deletion test/regular_jumps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ function compute_mean_at_saves(sol, Nsims, npts, species_idx)
mean_vals = zeros(npts)
for i in 1:Nsims
for j in 1:npts
mean_vals[j] += sol[i].u[j][species_idx]
mean_vals[j] += sol.u[i].u[j][species_idx]
end
end
mean_vals ./= Nsims
Expand Down
49 changes: 29 additions & 20 deletions test/variable_rate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ a .= b .+ c .+ d

rate = (u, p, t) -> u[1]
affect! = (integrator) -> (integrator.u[1] = integrator.u[1] / 2)
jump = VariableRateJump(rate, affect!, interp_points = 1000)
jump = VariableRateJump(rate, affect!)
jump2 = deepcopy(jump)

f = function (du, u, p, t)
Expand Down Expand Up @@ -271,12 +271,19 @@ end
# https://github.com/SciML/JumpProcesses.jl/issues/320
# note that even with the seeded StableRNG this test is not
# deterministic for some reason.
function getmean(Nsims, prob, alg, dt, tsave, seed)
function getmean(Nsims, prob, alg, tsave, seed)
umean = zeros(length(tsave))
for i in 1:Nsims
sol = solve(prob, alg; saveat = dt, seed)
umean .+= Array(sol(tsave; idxs = 1))
seed += 1
integrator = init(prob, alg; saveat = tsave, seed)
solve!(integrator)
for j in eachindex(umean)
umean[j] += integrator.sol.u[j][1]
end
for i in 2:Nsims
reinit!(integrator)
solve!(integrator)
for j in eachindex(umean)
umean[j] += integrator.sol.u[j][1]
end
end
umean ./= Nsims
return umean
Expand Down Expand Up @@ -304,23 +311,22 @@ let
integrator.u[1] += 1
nothing
end
b_jump = VariableRateJump(b_rate, birth!)
b_jump = VariableRateJump(b_rate, birth!; save_positions = (false, false))

d_rate(u, p, t) = (u[1] * p[2])
function death!(integrator)
integrator.u[1] -= 1
nothing
end
d_jump = VariableRateJump(d_rate, death!)
d_jump = VariableRateJump(d_rate, death!; save_positions = (false, false))

ode_prob = ODEProblem(ode_fxn, u0, tspan, p)
dt = 0.1
tsave = range(tspan[1], tspan[2]; step = dt)
tsave = range(tspan[1], tspan[2]; step = 0.1)
for vr_aggregator in (VR_Direct(), VR_DirectFW(), VR_FRM())
sjm_prob = JumpProblem(ode_prob, b_jump, d_jump; vr_aggregator, rng)

for alg in (Tsit5(), Rodas5P(linsolve = QRFactorization()))
umean = getmean(Nsims, sjm_prob, alg, dt, tsave, seed)
umean = getmean(Nsims, sjm_prob, alg, tsave, seed)
@test all(abs.(umean .- n.(tsave)) .< 0.05 * n.(tsave))
seed += Nsims
end
Expand All @@ -333,16 +339,19 @@ end
function run_ensemble(prob, alg, jumps...; vr_aggregator = VR_FRM(), Nsims = 8000)
rng = StableRNG(12345)
jump_prob = JumpProblem(prob, Direct(), jumps...; vr_aggregator, rng)
ensemble = EnsembleProblem(jump_prob)
sol = solve(ensemble, alg, trajectories = Nsims, save_everystep = false)
return mean(sol.u[i][1, end] for i in 1:Nsims)
total = 0.0
for i in 1:Nsims
sol = solve(jump_prob, alg; save_everystep = false)
total += sol.u[end][1]
end
return total / Nsims
end

# Test 1: Simple ODE with two variable rate jumps
let
rate = (u, p, t) -> u[1]
affect! = (integrator) -> (integrator.u[1] = integrator.u[1] / 2)
jump = VariableRateJump(rate, affect!, interp_points = 1000)
jump = VariableRateJump(rate, affect!; save_positions = (false, false))
jump2 = deepcopy(jump)

f = (du, u, p, t) -> (du[1] = u[1])
Expand All @@ -362,7 +371,7 @@ let
g = (du, u, p, t) -> (du[1] = -u[1] / 10.0)
rate = (u, p, t) -> u[1] / 10.0
affect! = (integrator) -> (integrator.u[1] = integrator.u[1] + 1)
jump = VariableRateJump(rate, affect!)
jump = VariableRateJump(rate, affect!; save_positions = (false, false))
jump2 = deepcopy(jump)

prob = SDEProblem(f, g, [10.0], (0.0, 10.0))
Expand All @@ -381,7 +390,7 @@ let
f = (du, u, p, t) -> (du[1] = -u[1]; nothing)
rate = (u, p, t) -> λ
affect! = (integrator) -> (integrator.u[1] += 1; nothing)
jump = VariableRateJump(rate, affect!)
jump = VariableRateJump(rate, affect!; save_positions = (false, false))

prob = ODEProblem(f, [0.2], (0.0, 10.0))

Expand Down Expand Up @@ -409,7 +418,7 @@ let
integrator.p[3] += 1
nothing
end
birth_jump = VariableRateJump(birth_rate, birth_affect!)
birth_jump = VariableRateJump(birth_rate, birth_affect!; save_positions = (false, false))

# Define death jump: X → ∅
death_rate(u, p, t) = 0.5 * u[1]
Expand All @@ -418,7 +427,7 @@ let
integrator.p[3] += 1
nothing
end
death_jump = VariableRateJump(death_rate, death_affect!)
death_jump = VariableRateJump(death_rate, death_affect!; save_positions = (false, false))

Nsims = 100
results = Dict()
Expand All @@ -431,7 +440,7 @@ let
jump_prob = JumpProblem(prob, Direct(), birth_jump, death_jump; vr_aggregator, rng)

for i in 1:Nsims
sol = solve(jump_prob, Tsit5())
sol = solve(jump_prob, Tsit5(); save_everystep = false)
jump_counts[i] = jump_prob.prob.p[3]
jump_prob.prob.p[3] = 0
end
Expand Down
Loading