Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 101 additions & 31 deletions src/simple_regular_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,45 @@ function validate_pure_leaping_inputs(jump_prob::JumpProblem, alg::SimpleExplici
jump_prob.massaction_jump !== nothing
end

"""
_process_saveat(saveat, tspan, save_start, save_end)

Process `saveat` into a sorted vector of strictly interior save times (excluding
both `tspan` endpoints), and resolve `save_start`/`save_end` defaults following
OrdinaryDiffEq conventions.

Endpoint saving is controlled purely by the returned `save_start`/`save_end`
flags. When the user passes `nothing` for these, defaults are:
- No saveat or saveat is a Number: `true` for both.
- saveat is a collection: `true` if the corresponding endpoint is `in` the collection.
"""
function _process_saveat(saveat, tspan, save_start, save_end)
t0, tf = tspan
if isnothing(saveat)
saveat_vec = Vector{typeof(t0)}()
_save_start = something(save_start, true)
_save_end = something(save_end, true)
elseif saveat isa Number
saveat_vec = collect(t0 + saveat:saveat:tf)
if !isempty(saveat_vec) && last(saveat_vec) == tf
pop!(saveat_vec)
end
_save_start = something(save_start, true)
_save_end = something(save_end, true)
else
saveat_vec = sort!(collect(saveat))
_save_start = something(save_start, insorted(t0, saveat_vec))
_save_end = something(save_end, insorted(tf, saveat_vec))
lo = searchsortedlast(saveat_vec, t0) + 1
hi = searchsortedfirst(saveat_vec, tf) - 1
saveat_vec = saveat_vec[lo:hi]
end
return saveat_vec, _save_start, _save_end
end

function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleTauLeaping;
seed = nothing, dt = error("dt is required for SimpleTauLeaping."))
seed = nothing, dt = error("dt is required for SimpleTauLeaping."),
saveat = nothing, save_start = nothing, save_end = nothing)
validate_pure_leaping_inputs(jump_prob, alg) ||
error("SimpleTauLeaping can only be used with PureLeaping JumpProblems with only RegularJumps.")

Expand All @@ -58,26 +95,56 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleTauLeaping;
p = prob.p

n = Int((tspan[2] - tspan[1]) / dt) + 1
u = Vector{typeof(prob.u0)}(undef, n)
u[1] = u0
t = tspan[1]:dt:tspan[2]

# iteration variables
counts = zero(rate_cache) # counts for each variable
saveat_times, save_start, save_end = _process_saveat(saveat, tspan, save_start, save_end)

if save_start
usave = [copy(u0)]
tsave = typeof(tspan[1])[tspan[1]]
else
usave = typeof(u0)[]
tsave = typeof(tspan[1])[]
end
save_idx = 1

# Pre-allocate working buffers — swap each step to avoid copying
uprev = u0 # u0 is already a copy
u_new = similar(u0)
counts = zero(rate_cache)

for i in 2:n # iterate over dt-slices
uprev = u[i - 1]
tprev = t[i - 1]
for i in 2:n
tprev = tspan[1] + (i - 2) * dt
t_new = tprev + dt
rate(rate_cache, uprev, p, tprev)
rate_cache .*= dt # multiply by the width of the time interval
counts .= pois_rand.((rng,), rate_cache) # set counts to the poisson arrivals with our given rates
rate_cache .*= dt
counts .= pois_rand.((rng,), rate_cache)
c(du, uprev, p, tprev, counts, mark)
u[i] = du + uprev
u_new .= du .+ uprev

# Save logic — only allocate (via copy) when actually saving
if isempty(saveat_times)
push!(usave, copy(u_new))
push!(tsave, t_new)
else
while save_idx <= length(saveat_times) && t_new >= saveat_times[save_idx]
push!(usave, copy(u_new))
push!(tsave, saveat_times[save_idx])
save_idx += 1
end
end

uprev, u_new = u_new, uprev
end

sol = DiffEqBase.build_solution(prob, alg, t, u,
# Save endpoint if requested and not already saved
if save_end && (isempty(tsave) || tsave[end] != tspan[2])
push!(usave, copy(uprev))
push!(tsave, tspan[2])
end

sol = DiffEqBase.build_solution(prob, alg, tsave, usave,
calculate_error = false,
interp = DiffEqBase.ConstantInterpolation(t, u))
interp = DiffEqBase.ConstantInterpolation(tsave, usave))
end

# Compute the highest order of reaction (HOR) for each reaction j, as per Cao et al. (2006), Section IV.
Expand Down Expand Up @@ -202,15 +269,13 @@ end
function simple_explicit_tau_leaping_loop!(
prob, alg, u_current, u_new, t_current, t_end, p, rng,
rate, c, nu, hor, max_hor, max_stoich, numjumps, epsilon,
dtmin, saveat_times, usave, tsave, du, counts, rate_cache, rate_effective, maj)
dtmin, saveat_times, usave, tsave, du, counts, rate_cache, rate_effective, maj,
save_end)
save_idx = 1

while t_current < t_end
rate(rate_cache, u_current, p, t_current)
if all(<=(0), rate_cache) # No reactions can occur, step to final time
# Save final state at t_end
push!(usave, copy(u_current))
push!(tsave, t_end)
t_current = t_end
break
end
Expand Down Expand Up @@ -261,12 +326,18 @@ function simple_explicit_tau_leaping_loop!(
u_current .= u_new
t_current = t_new
end

# Save endpoint if requested and not already saved
if save_end && (isempty(tsave) || tsave[end] != t_end)
push!(usave, copy(u_current))
push!(tsave, t_end)
end
end

function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleExplicitTauLeaping;
seed = nothing,
dtmin = nothing,
saveat = nothing)
saveat = nothing, save_start = nothing, save_end = nothing)
validate_pure_leaping_inputs(jump_prob, alg) ||
error("SimpleExplicitTauLeaping can only be used with PureLeaping JumpProblem with a MassActionJump.")

Expand All @@ -289,12 +360,19 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleExplicitTauLeaping;
u0 = copy(prob.u0)
p = prob.p

saveat_times, save_start, save_end = _process_saveat(saveat, tspan, save_start, save_end)

# Initialize current state and saved history
u_current = copy(u0)
u_new = similar(u0)
t_current = tspan[1]
usave = [copy(u0)]
tsave = [tspan[1]]
if save_start
usave = [copy(u0)]
tsave = [tspan[1]]
else
usave = typeof(u0)[]
tsave = typeof(tspan[1])[]
end
rate_cache = zeros(float(eltype(u0)), numjumps)
rate_effective = similar(rate_cache)
counts = zero(rate_cache)
Expand All @@ -315,19 +393,11 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleExplicitTauLeaping;
max_hor, max_stoich = precompute_reaction_conditions(
reactant_stoch, hor, length(u0), numjumps)

# Set up saveat_times
if isnothing(saveat)
saveat_times = Vector{typeof(tspan[1])}()
elseif saveat isa Number
saveat_times = collect(range(tspan[1], tspan[2], step = saveat))
else
saveat_times = collect(saveat)
end

simple_explicit_tau_leaping_loop!(
prob, alg, u_current, u_new, t_current, t_end, p, rng,
rate, c, nu, hor, max_hor, max_stoich, numjumps, epsilon,
dtmin, saveat_times, usave, tsave, du, counts, rate_cache, rate_effective, maj)
dtmin, saveat_times, usave, tsave, du, counts, rate_cache, rate_effective, maj,
save_end)

sol = DiffEqBase.build_solution(prob, alg, tsave, usave,
calculate_error = false,
Expand Down
Loading
Loading