Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
101 commits
Select commit Hold shift + click to select a range
f5ad900
Implemented SimpleAdaptiveTauLeaping and SimpleImplicitTauLeaping
sivasathyaseeelan Aug 9, 2025
aae5d9f
update project.toml
sivasathyaseeelan Aug 10, 2025
230d508
test changes
sivasathyaseeelan Aug 10, 2025
25b8c16
refactor
sivasathyaseeelan Aug 12, 2025
07c429e
test refactor
sivasathyaseeelan Aug 12, 2025
885ac59
refactor
sivasathyaseeelan Aug 13, 2025
0ec9d39
added saveat in SimpleAdaptiveTauLeaping
sivasathyaseeelan Aug 16, 2025
0e7ff26
update
sivasathyaseeelan Aug 24, 2025
89378c6
Update src/simple_regular_solve.jl
sivasathyaseeelan Aug 24, 2025
14e0be7
Update src/simple_regular_solve.jl
sivasathyaseeelan Aug 24, 2025
e7f975e
update
sivasathyaseeelan Aug 24, 2025
6e789cd
test update
sivasathyaseeelan Aug 24, 2025
a8999f4
using maj for adaptive tauleaping
sivasathyaseeelan Aug 24, 2025
0125e21
project.toml update
sivasathyaseeelan Aug 24, 2025
1af190d
saveat logic change
sivasathyaseeelan Aug 24, 2025
10f4ce3
test change
sivasathyaseeelan Aug 24, 2025
6d3d900
saveat optimization
sivasathyaseeelan Aug 25, 2025
2c03d67
refactor
sivasathyaseeelan Aug 25, 2025
3f90750
memory optimization
sivasathyaseeelan Aug 25, 2025
fb72149
validate_pure_leaping_inputs extended for adaptive version
sivasathyaseeelan Aug 25, 2025
7a7232a
some
sivasathyaseeelan Aug 25, 2025
fe7cec0
space optimized in compute_tau_explicit
sivasathyaseeelan Aug 25, 2025
8e7ff16
computegi and comutehor changes
sivasathyaseeelan Aug 25, 2025
bc770d1
reactant_stoch in hor
sivasathyaseeelan Aug 25, 2025
b5f77f5
compute_gi update
sivasathyaseeelan Aug 26, 2025
0b72d4c
added references
sivasathyaseeelan Aug 26, 2025
5415947
added unpack
sivasathyaseeelan Aug 27, 2025
b39390f
test changes
sivasathyaseeelan Aug 27, 2025
822562f
test changes
sivasathyaseeelan Aug 27, 2025
b572987
export changes
sivasathyaseeelan Aug 27, 2025
785266b
test changes
sivasathyaseeelan Aug 27, 2025
b47df7c
some change in gi calculation
sivasathyaseeelan Aug 28, 2025
e02d432
changed compute_gi as per paper
sivasathyaseeelan Aug 28, 2025
092d361
some
sivasathyaseeelan Aug 28, 2025
7217cf0
some
sivasathyaseeelan Aug 28, 2025
cc3a78a
optimized compute_gi
sivasathyaseeelan Aug 29, 2025
48fece2
zero rates case for SimpleAdaptiveTauLeaping is added
sivasathyaseeelan Aug 30, 2025
12f84ba
SimpleExplicitTauLeaping
sivasathyaseeelan Sep 5, 2025
98d64f3
test update
sivasathyaseeelan Sep 5, 2025
00eb432
Implemented SimpleAdaptiveTauLeaping and SimpleImplicitTauLeaping
sivasathyaseeelan Aug 9, 2025
0be66a2
update project.toml
sivasathyaseeelan Aug 10, 2025
3d63e23
test changes
sivasathyaseeelan Aug 10, 2025
5d2a7cb
refactor
sivasathyaseeelan Aug 12, 2025
f19ed98
test refactor
sivasathyaseeelan Aug 12, 2025
a5f6852
refactor
sivasathyaseeelan Aug 13, 2025
4a11df4
added saveat in SimpleAdaptiveTauLeaping
sivasathyaseeelan Aug 16, 2025
2d58c88
update
sivasathyaseeelan Aug 24, 2025
b632a0f
Update src/simple_regular_solve.jl
sivasathyaseeelan Aug 24, 2025
0a2664c
Update src/simple_regular_solve.jl
sivasathyaseeelan Aug 24, 2025
c77c119
update
sivasathyaseeelan Aug 24, 2025
69a5442
test update
sivasathyaseeelan Aug 24, 2025
55cf8ec
using maj for adaptive tauleaping
sivasathyaseeelan Aug 24, 2025
9aea095
project.toml update
sivasathyaseeelan Aug 24, 2025
dbc33c2
saveat logic change
sivasathyaseeelan Aug 24, 2025
450daca
test change
sivasathyaseeelan Aug 24, 2025
bb1f5df
saveat optimization
sivasathyaseeelan Aug 25, 2025
a0af5ce
refactor
sivasathyaseeelan Aug 25, 2025
0d3b241
memory optimization
sivasathyaseeelan Aug 25, 2025
d8507b8
validate_pure_leaping_inputs extended for adaptive version
sivasathyaseeelan Aug 25, 2025
df03880
some
sivasathyaseeelan Aug 25, 2025
c47a75e
space optimized in compute_tau_explicit
sivasathyaseeelan Aug 25, 2025
72de260
computegi and comutehor changes
sivasathyaseeelan Aug 25, 2025
cf5d3f4
reactant_stoch in hor
sivasathyaseeelan Aug 25, 2025
e8eb3d1
compute_gi update
sivasathyaseeelan Aug 26, 2025
1028e68
added references
sivasathyaseeelan Aug 26, 2025
710b6cf
added unpack
sivasathyaseeelan Aug 27, 2025
b82ee00
test changes
sivasathyaseeelan Aug 27, 2025
951c071
test changes
sivasathyaseeelan Aug 27, 2025
9441d25
export changes
sivasathyaseeelan Aug 27, 2025
b0623ce
test changes
sivasathyaseeelan Aug 27, 2025
3cc5341
some change in gi calculation
sivasathyaseeelan Aug 28, 2025
ede1f22
changed compute_gi as per paper
sivasathyaseeelan Aug 28, 2025
20e9d37
some
sivasathyaseeelan Aug 28, 2025
da5b03c
some
sivasathyaseeelan Aug 28, 2025
3b14401
optimized compute_gi
sivasathyaseeelan Aug 29, 2025
3d19c42
zero rates case for SimpleAdaptiveTauLeaping is added
sivasathyaseeelan Aug 30, 2025
c8f1106
SimpleExplicitTauLeaping
sivasathyaseeelan Sep 5, 2025
8baf04e
test update
sivasathyaseeelan Sep 5, 2025
7fd548d
addressed some comments
sivasathyaseeelan Jan 16, 2026
a2b9f1b
comments above function
sivasathyaseeelan Jan 16, 2026
ae9bcf5
removed hardcoded type
sivasathyaseeelan Jan 16, 2026
7ac89f6
removed some hardcoded floating points
sivasathyaseeelan Jan 16, 2026
e49d170
Merge branch 'SciML:master' into adaptive-implicit
sivasathyaseeelan Jan 16, 2026
9424834
unpack fix
sivasathyaseeelan Jan 16, 2026
4caeede
negative case handel
sivasathyaseeelan Jan 17, 2026
4e800be
hardcoded value remove
sivasathyaseeelan Jan 17, 2026
79cea7e
hardcoded value remove
sivasathyaseeelan Jan 17, 2026
dbde117
type fix
sivasathyaseeelan Jan 17, 2026
522f63d
remove unwanted code
sivasathyaseeelan Jan 17, 2026
90cf701
posi rand optimization
sivasathyaseeelan Jan 17, 2026
2a6db45
some reviews resolved
sivasathyaseeelan Jan 17, 2026
3fbcf4f
u_new allocated
sivasathyaseeelan Jan 17, 2026
3477eb8
stand alone function for mass action rate
sivasathyaseeelan Jan 17, 2026
df7f485
simple explicit loop
sivasathyaseeelan Jan 17, 2026
c237d7c
bug and lint fix
sivasathyaseeelan Jan 18, 2026
c601fd3
bug fix
sivasathyaseeelan Jan 18, 2026
0e5cb7b
update gitignore
isaacsas Jan 26, 2026
7ff8a83
update to master
isaacsas Jan 26, 2026
386cd8d
minor fixes
isaacsas Jan 26, 2026
c858d83
use seed! consistently
isaacsas Jan 26, 2026
1f1f6a8
ci bug fix
sivasathyaseeelan Feb 3, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,18 @@
*.jl.*.cov
*.jl.mem
Manifest.toml
.vscode/
.claude/
.claude/*

# Vim files
*.swp

# vscode stuff
.vscode
.vscode/*

LocalPreferences.toml

# claude
.claude
.claude/*
.claude_plans
.claude_plans/*
4 changes: 2 additions & 2 deletions src/JumpProcesses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ using Reexport: Reexport, @reexport

# Explicit imports from standard libraries
using LinearAlgebra: LinearAlgebra, mul!
using Random: Random, randexp, randexp!
using Random: Random, randexp, randexp!, seed!

# Explicit imports from external packages
using DocStringExtensions: DocStringExtensions, FIELDS, TYPEDEF
Expand Down Expand Up @@ -128,7 +128,7 @@ export SSAStepper

# leaping:
include("simple_regular_solve.jl")
export SimpleTauLeaping, EnsembleGPUKernel
export SimpleTauLeaping, SimpleExplicitTauLeaping, EnsembleGPUKernel

# spatial:
include("spatial/spatial_massaction_jump.jl")
Expand Down
286 changes: 280 additions & 6 deletions src/simple_regular_solve.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,36 @@
struct SimpleTauLeaping <: DiffEqBase.DEAlgorithm end

struct SimpleExplicitTauLeaping{T <: AbstractFloat} <: DiffEqBase.DEAlgorithm
epsilon::T # Error control parameter
end

SimpleExplicitTauLeaping(; epsilon = 0.05) = SimpleExplicitTauLeaping(epsilon)

function validate_pure_leaping_inputs(jump_prob::JumpProblem, alg)
if !(jump_prob.aggregator isa PureLeaping)
@warn "When using $alg, please pass PureLeaping() as the aggregator to the \
JumpProblem, i.e. call JumpProblem(::DiscreteProblem, PureLeaping(),...). \
Passing $(jump_prob.aggregator) is deprecated and will be removed in the next breaking release."
end
isempty(jump_prob.jump_callback.continuous_callbacks) &&
isempty(jump_prob.jump_callback.discrete_callbacks) &&
isempty(jump_prob.constant_jumps) &&
isempty(jump_prob.variable_jumps) &&
get_num_majumps(jump_prob.massaction_jump) == 0 &&
jump_prob.regular_jump !== nothing
isempty(jump_prob.jump_callback.discrete_callbacks) &&
isempty(jump_prob.constant_jumps) &&
isempty(jump_prob.variable_jumps) &&
get_num_majumps(jump_prob.massaction_jump) == 0 &&
jump_prob.regular_jump !== nothing
end

function validate_pure_leaping_inputs(jump_prob::JumpProblem, alg::SimpleExplicitTauLeaping)
if !(jump_prob.aggregator isa PureLeaping)
@warn "When using $alg, please pass PureLeaping() as the aggregator to the \
JumpProblem, i.e. call JumpProblem(::DiscreteProblem, PureLeaping(),...). \
Passing $(jump_prob.aggregator) is deprecated and will be removed in the next breaking release."
end
isempty(jump_prob.jump_callback.continuous_callbacks) &&
isempty(jump_prob.jump_callback.discrete_callbacks) &&
isempty(jump_prob.constant_jumps) &&
isempty(jump_prob.variable_jumps) &&
jump_prob.massaction_jump !== nothing
end

function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleTauLeaping;
Expand All @@ -20,7 +39,7 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleTauLeaping;
error("SimpleTauLeaping can only be used with PureLeaping JumpProblems with only RegularJumps.")

(; prob, rng) = jump_prob
(seed !== nothing) && Random.seed!(rng, seed)
(seed !== nothing) && seed!(rng, seed)

rj = jump_prob.regular_jump
rate = rj.rate # rate function rate(out,u,p,t)
Expand Down Expand Up @@ -61,6 +80,261 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleTauLeaping;
interp = DiffEqBase.ConstantInterpolation(t, u))
end

# Compute the highest order of reaction (HOR) for each reaction j, as per Cao et al. (2006), Section IV.
# HOR is the sum of stoichiometric coefficients of reactants in reaction j.
# Extract the element type from reactant_stoch to avoid hardcoding type assumptions.
function compute_hor(reactant_stoch, numjumps)
stoch_type = eltype(first(first(reactant_stoch)))
hor = zeros(stoch_type, numjumps)
for j in 1:numjumps
order = sum(
stoch for (spec_idx, stoch) in reactant_stoch[j]; init = zero(stoch_type))
if order > 3
error("Reaction $j has order $order, which is not supported (maximum order is 3).")
end
hor[j] = order
end
return hor
end

# Precompute reaction conditions for each species i, including:
# - max_hor: the highest order of reaction (HOR) where species i is a reactant.
# - max_stoich: the maximum stoichiometry (nu_ij) in reactions with max_hor.
# Used to optimize compute_gi, as per Cao et al. (2006), Section IV, equation (27).
function precompute_reaction_conditions(reactant_stoch, hor, numspecies, numjumps)
hor_type = eltype(hor)
max_hor = zeros(hor_type, numspecies)
max_stoich = zeros(hor_type, numspecies)
for j in 1:numjumps
for (spec_idx, stoch) in reactant_stoch[j]
if stoch > 0 # Species is a reactant
if hor[j] > max_hor[spec_idx]
max_hor[spec_idx] = hor[j]
max_stoich[spec_idx] = stoch
elseif hor[j] == max_hor[spec_idx]
max_stoich[spec_idx] = max(max_stoich[spec_idx], stoch)
end
end
end
end
return max_hor, max_stoich
end

# Compute g_i for species i to bound the relative change in propensity functions,
# as per Cao et al. (2006), Section IV, equation (27).
# g_i is determined by the highest order of reaction (HOR) and maximum stoichiometry (nu_ij) where species i is a reactant:
# - HOR = 1 (first-order, e.g., S_i -> products): g_i = 1
# - HOR = 2 (second-order):
# - nu_ij = 1 (e.g., S_i + S_k -> products): g_i = 2
# - nu_ij = 2 (e.g., 2S_i -> products): g_i = 2 + 1/(x_i - 1)
# - HOR = 3 (third-order):
# - nu_ij = 1 (e.g., S_i + S_k + S_m -> products): g_i = 3
# - nu_ij = 2 (e.g., 2S_i + S_k -> products): g_i = (3/2) * (2 + 1/(x_i - 1))
# - nu_ij = 3 (e.g., 3S_i -> products): g_i = 3 + 1/(x_i - 1) + 2/(x_i - 2)
# Uses precomputed max_hor and max_stoich to reduce work to O(num_species) per timestep.
function compute_gi(u, max_hor, max_stoich, i, t)
one_max_hor = one(1 / one(eltype(u)))

if max_hor[i] == 0 # No reactions involve species i as a reactant
return one_max_hor
elseif max_hor[i] == 1
return one_max_hor
elseif max_hor[i] == 2
if max_stoich[i] == 1
return 2 * one_max_hor
else # if max_stoich[i] == 2
return u[i] > one_max_hor ?
2 * one_max_hor + one_max_hor / (u[i] - one_max_hor) : 2 * one_max_hor # Fallback to 2 if x_i <= 1
end
elseif max_hor[i] == 3
if max_stoich[i] == 1
return 3 * one_max_hor
elseif max_stoich[i] == 2
return u[i] > one_max_hor ?
(3 * one_max_hor / 2) *
(2 * one_max_hor + one_max_hor / (u[i] - one_max_hor)) : 3 * one_max_hor # Fallback to 3 if x_i <= 1
else # if max_stoich[i] == 3
return u[i] > 2 * one_max_hor ?
3 * one_max_hor + one_max_hor / (u[i] - one_max_hor) +
2 * one_max_hor / (u[i] - 2 * one_max_hor) : 3 * one_max_hor # Fallback to 3 if x_i <= 2
end
end
return one_max_hor # Default case
end

# Compute the tau-leaping step-size using equation (20) from Cao et al. (2006):
# tau = min_{i in I_rs} { max(epsilon * x_i / g_i, 1) / |mu_i(x)|, max(epsilon * x_i / g_i, 1)^2 / sigma_i^2(x) }
# where mu_i(x) and sigma_i^2(x) are defined in equations (9a) and (9b):
# mu_i(x) = sum_j nu_ij * a_j(x), sigma_i^2(x) = sum_j nu_ij^2 * a_j(x)
# I_rs is the set of reactant species (assumed to be all species here, as critical reactions are not specified).
function compute_tau(
u, rate_cache, nu, hor, p, t, epsilon, rate, dtmin, max_hor, max_stoich, numjumps)
rate(rate_cache, u, p, t)
if all(<=(0), rate_cache) # Handle case where all rates are zero or negative
return dtmin
end
tau = typemax(typeof(t))
for i in 1:length(u)
mu = zero(eltype(u))
sigma2 = zero(eltype(u))
for j in 1:size(nu, 2)
mu += nu[i, j] * rate_cache[j] # Equation (9a)
sigma2 += nu[i, j]^2 * rate_cache[j] # Equation (9b)
end
gi = compute_gi(u, max_hor, max_stoich, i, t)
bound = max(epsilon * u[i] / gi, one(eltype(u))) # max(epsilon * x_i / g_i, 1)
mu_term = abs(mu) > 0 ? bound / abs(mu) : typemax(typeof(t)) # First term in equation (8)
sigma_term = sigma2 > 0 ? bound^2 / sigma2 : typemax(typeof(t)) # Second term in equation (8)
tau = min(tau, mu_term, sigma_term) # Equation (8)
end
return max(tau, dtmin)
Comment thread
isaacsas marked this conversation as resolved.
end

# Function to generate a mass action rate function
function massaction_rate(maj, numjumps)
return (out, u, p, t) -> begin
for j in 1:numjumps
out[j] = evalrxrate(u, j, maj)
end
end
end

function simple_explicit_tau_leaping_loop!(
prob, alg, u_current, u_new, t_current, t_end, p, rng,
rate, c, nu, hor, max_hor, max_stoich, numjumps, epsilon,
dtmin, saveat_times, usave, tsave, du, counts, rate_cache, rate_effective, maj)
save_idx = 1

while t_current < t_end
rate(rate_cache, u_current, p, t_current)
if all(<=(0), rate_cache) # No reactions can occur, step to final time
# Save final state at t_end
push!(usave, copy(u_current))
push!(tsave, t_end)
t_current = t_end
break
end
tau = compute_tau(u_current, rate_cache, nu, hor, p, t_current,
epsilon, rate, dtmin, max_hor, max_stoich, numjumps)
tau = min(tau, t_end - t_current)
if !isempty(saveat_times) && save_idx <= length(saveat_times) &&
t_current + tau > saveat_times[save_idx]
tau = saveat_times[save_idx] - t_current
end
# Calculate Poisson random numbers only for positive rates
rate_effective .= rate_cache .* tau
for j in eachindex(counts)
if rate_effective[j] <= zero(eltype(rate_effective))
counts[j] = zero(eltype(counts))
else
counts[j] = pois_rand(rng, rate_effective[j])
end
end
du .= 0
if c !== nothing
c(du, u_current, p, t_current, counts, nothing)
else
for j in 1:numjumps
for (spec_idx, stoch) in maj.net_stoch[j]
du[spec_idx] += stoch * counts[j]
end
end
end
u_new .= u_current .+ du
if any(<(0), u_new)
# Halve tau to avoid negative populations, as per Cao et al. (2006), Section 3.3
tau /= 2
continue
end
t_new = t_current + tau

# Save state if at a saveat time or if saveat is empty
if isempty(saveat_times) ||
(save_idx <= length(saveat_times) && t_new >= saveat_times[save_idx])
push!(usave, copy(u_new))
push!(tsave, t_new)
if !isempty(saveat_times) && t_new >= saveat_times[save_idx]
save_idx += 1
end
end

u_current .= u_new
t_current = t_new
end
end

function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleExplicitTauLeaping;
seed = nothing,
dtmin = nothing,
saveat = nothing)
validate_pure_leaping_inputs(jump_prob, alg) ||
error("SimpleExplicitTauLeaping can only be used with PureLeaping JumpProblem with a MassActionJump.")

prob = jump_prob.prob
rng = jump_prob.rng
tspan = prob.tspan

if dtmin === nothing
dtmin = 1e-10 * one(typeof(tspan[2]))
end

(seed !== nothing) && seed!(rng, seed)

maj = jump_prob.massaction_jump
numjumps = get_num_majumps(maj)
rj = jump_prob.regular_jump
# Extract rates
rate = rj !== nothing ? rj.rate : massaction_rate(maj, numjumps)
c = rj !== nothing ? rj.c : nothing
u0 = copy(prob.u0)
p = prob.p

# Initialize current state and saved history
u_current = copy(u0)
u_new = similar(u0)
t_current = tspan[1]
usave = [copy(u0)]
tsave = [tspan[1]]
rate_cache = zeros(float(eltype(u0)), numjumps)
rate_effective = similar(rate_cache)
counts = zero(rate_cache)
du = similar(u0)
t_end = tspan[2]
epsilon = alg.epsilon

# Extract net stoichiometry for state updates
nu = zeros(float(eltype(u0)), length(u0), numjumps)
for j in 1:numjumps
for (spec_idx, stoch) in maj.net_stoch[j]
nu[spec_idx, j] = stoch
end
end
# Extract reactant stoichiometry for hor and gi
reactant_stoch = maj.reactant_stoch
hor = compute_hor(reactant_stoch, numjumps)
max_hor, max_stoich = precompute_reaction_conditions(
reactant_stoch, hor, length(u0), numjumps)

# Set up saveat_times
if isnothing(saveat)
saveat_times = Vector{typeof(tspan[1])}()
elseif saveat isa Number
saveat_times = collect(range(tspan[1], tspan[2], step = saveat))
else
saveat_times = collect(saveat)
end
Comment thread
sivasathyaseeelan marked this conversation as resolved.

simple_explicit_tau_leaping_loop!(
prob, alg, u_current, u_new, t_current, t_end, p, rng,
rate, c, nu, hor, max_hor, max_stoich, numjumps, epsilon,
dtmin, saveat_times, usave, tsave, du, counts, rate_cache, rate_effective, maj)

sol = DiffEqBase.build_solution(prob, alg, tsave, usave,
calculate_error = false,
interp = DiffEqBase.ConstantInterpolation(tsave, usave))
return sol
end

struct EnsembleGPUKernel{Backend} <: SciMLBase.EnsembleAlgorithm
backend::Backend
cpu_offload::Float64
Expand Down
Loading
Loading