Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,11 @@ SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8"

[extensions]
JumpProcessesKernelAbstractionsExt = ["Adapt", "KernelAbstractions"]
JumpProcessesOrdinaryDiffEqCoreExt = "OrdinaryDiffEqCore"

[compat]
ADTypes = "1"
Expand All @@ -45,7 +47,7 @@ KernelAbstractions = "0.9"
LinearAlgebra = "1"
LinearSolve = "3"
OrdinaryDiffEq = "6"
OrdinaryDiffEqCore = "1.32.0"
OrdinaryDiffEqCore = "3"
Pkg = "1"
PoissonRandom = "0.4"
Random = "1"
Expand All @@ -69,7 +71,6 @@ FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Expand Down
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"

[compat]
Catalyst = "14.0, 15"
Catalyst = "16"
DensityInterface = "0.4"
DifferentialEquations = "7.11"
Distributions = "0.25"
Expand Down
2 changes: 1 addition & 1 deletion docs/src/assets/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"

[compat]
Catalyst = "14.0, 15"
Catalyst = "16"
DensityInterface = "0.4"
DifferentialEquations = "7.11"
Distributions = "0.25"
Expand Down
29 changes: 9 additions & 20 deletions docs/src/tutorials/discrete_stochastic_example.md
Original file line number Diff line number Diff line change
Expand Up @@ -190,36 +190,25 @@ sir_model = @reaction_network begin
end
```

To build a pure jump process model of the reaction system, where the state is
constant between jumps, we will use a
[`DiscreteProblem`](https://docs.sciml.ai/DiffEqDocs/stable/types/discrete_types/).
This encodes that the state only changes at the jump times. We do this by giving
the constructor `u₀`, the initial condition, and `tspan`, the timespan. Here, we
will start with ``990`` susceptible people, ``10`` infected person, and `0` recovered
people, and solve the problem from `t=0.0` to `t=250.0`. We use the parameters
`β = 0.1/1000` and `ν = 0.01`. Thus, we build the problem via:
To build a pure jump process model of the reaction system we construct a
[`JumpProblem`](@ref) directly from the Catalyst `ReactionSystem`. We specify
`u₀`, the initial condition, `tspan`, the timespan, and `p`, the parameters.
Here, we will start with ``990`` susceptible people, ``10`` infected person, and
`0` recovered people, and solve the problem from `t=0.0` to `t=250.0`. We use
the parameters `β = 0.1/1000` and `ν = 0.01`.

```@example tut2
p = (:β => 0.1 / 1000, :ν => 0.01)
u₀ = [:S => 990, :I => 10, :R => 0]
tspan = (0.0, 250.0)
prob = DiscreteProblem(sir_model, u₀, tspan, p)
jump_prob = JumpProblem(sir_model, u₀, tspan, p)
```

*Notice, the initial populations are integers, since we want the exact number of
people in the different states.*

The Catalyst reaction network can be converted into various
DifferentialEquations.jl problem types, including `JumpProblem`s, `ODEProblem`s,
or `SDEProblem`s. To turn it into a [`JumpProblem`](@ref) representing the SIR jump
process model, we simply write

```@example tut2
jump_prob = JumpProblem(sir_model, prob, Direct())
```

Here `Direct()` indicates that we will determine the random times and types of
reactions using [Gillespie's Direct stochastic simulation algorithm
Here `Direct()` is the default aggregator, which determines the random times and
types of reactions using [Gillespie's Direct stochastic simulation algorithm
(SSA)](https://doi.org/10.1016/0021-9991(76)90041-3), also known as Doob's
method or Kinetic Monte Carlo. See [Jump Aggregators for Exact Simulation](@ref) for
other supported SSAs.
Expand Down
23 changes: 23 additions & 0 deletions ext/JumpProcessesOrdinaryDiffEqCoreExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
module JumpProcessesOrdinaryDiffEqCoreExt

using JumpProcesses
import DiffEqBase
import OrdinaryDiffEqCore: OrdinaryDiffEqAlgorithm, DAEAlgorithm

# Ambiguity fix: OrdinaryDiffEqCore defines
# __init(::Union{..., AbstractJumpProblem}, ::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm, ...})
# which is ambiguous with JumpProcesses'
# __init(::AbstractJumpProblem{P}, ::DEAlgorithm)
#
# IMPORTANT: Only ODE/DAE algorithms here. SDE/RODE algorithms are intentionally
# excluded because StochasticDiffEq defines its own __init for
# (JumpProblem, StochasticDiffEqAlgorithm) that handles jump-diffusion setup.
function DiffEqBase.__init(
_jump_prob::DiffEqBase.AbstractJumpProblem{P},
alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm};
merge_callbacks = true, kwargs...) where {P}
kwargs = DiffEqBase.merge_problem_kwargs(_jump_prob; merge_callbacks, kwargs...)
JumpProcesses.__jump_init(_jump_prob, alg; kwargs...)
end

end
6 changes: 4 additions & 2 deletions src/SSA_stepper.jl
Original file line number Diff line number Diff line change
Expand Up @@ -424,8 +424,10 @@ function should_continue_solve(integrator::SSAIntegrator)
integrator.keep_stepping && (has_jump || has_tstop)
end

function reset_aggregated_jumps!(integrator::SSAIntegrator, uprev = nothing)
reset_aggregated_jumps!(integrator, uprev, integrator.cb)
function reset_aggregated_jumps!(integrator::SSAIntegrator, uprev = nothing;
update_jump_params = true, kwargs...)
reset_aggregated_jumps!(integrator, uprev, integrator.cb;
update_jump_params, kwargs...)
nothing
end

Expand Down
17 changes: 16 additions & 1 deletion src/extended_jump_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,23 @@ function ArrayInterface.zeromatrix(A::ExtendedJumpArray)
u = [vec(A.u); vec(A.jump_u)]
u .* u' .* false
end

# Helper: concatenate fields into a flat vector, apply op, scatter back
function _eja_flat_apply_and_scatter!(op!, A, b::ExtendedJumpArray)
N = length(b.u)
tmp = [vec(b.u); vec(b.jump_u)]
op!(A, tmp)
copyto!(vec(b.u), 1, tmp, 1, N)
copyto!(vec(b.jump_u), 1, tmp, N + 1, length(b.jump_u))
b
end

function LinearAlgebra.ldiv!(A::LinearAlgebra.LU, b::ExtendedJumpArray)
LinearAlgebra.ldiv!(A, [vec(b.u); vec(b.jump_u)])
_eja_flat_apply_and_scatter!(LinearAlgebra.ldiv!, A, b)
end

function LinearAlgebra.lmul!(A::LinearAlgebra.AbstractQ, b::ExtendedJumpArray)
_eja_flat_apply_and_scatter!(LinearAlgebra.lmul!, A, b)
end

function recursivecopy!(dest::T, src::T) where {T <: ExtendedJumpArray}
Expand Down
43 changes: 42 additions & 1 deletion test/extended_jump_array.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using Test, JumpProcesses, DiffEqBase, OrdinaryDiffEq, SciMLBase
using Test, JumpProcesses, DiffEqBase, OrdinaryDiffEq, SciMLBase, LinearAlgebra, LinearSolve
using FastBroadcast
using StableRNGs

Expand Down Expand Up @@ -118,3 +118,44 @@ let
@test eltype(sol.u) <: ExtendedJumpArray{Float64, 1, Vector{Float64}, Vector{Float64}}
@test SciMLBase.plottable_indices(sol.u[1]) == 1:length(u₀)
end

# Test ldiv! and lmul! for stiff solver support
let rng = StableRNG(456)
u = rand(rng, 3)
jump_u = rand(rng, 2)
flat = [u; jump_u]

# ldiv! with LU should modify eja in place and match plain vector result
eja = ExtendedJumpArray(copy(u), copy(jump_u))
A = rand(rng, 5, 5) + 5I
F = lu(A)
expected = F \ flat
ldiv!(F, eja)
@test vcat(eja.u, eja.jump_u) ≈ expected

# lmul! with Q from QR
eja2 = ExtendedJumpArray(copy(u), copy(jump_u))
Q = qr(rand(rng, 5, 5)).Q
expected_q = Q * flat
lmul!(Q, eja2)
@test vcat(eja2.u, eja2.jump_u) ≈ expected_q

# lmul! with AdjointQ from QR (the actual CI failure case)
eja3 = ExtendedJumpArray(copy(u), copy(jump_u))
expected_qt = Q' * flat
lmul!(Q', eja3)
@test vcat(eja3.u, eja3.jump_u) ≈ expected_qt
end

# Integration test: stiff solver with QRFactorization and ExtendedJumpArray
let
f!(du, u, p, t) = (du .= 0; nothing)
rate(u, p, t) = 0.5
affect!(integrator) = (integrator.u[1] += 1; nothing)
vrj = VariableRateJump(rate, affect!)
oprob = ODEProblem(f!, [0.0], (0.0, 1.0))
jprob = JumpProblem(oprob, Direct(), vrj; vr_aggregator = VR_FRM(),
rng = StableRNG(789))
sol = solve(jprob, Rodas5P(linsolve = QRFactorization()))
@test sol.retcode == ReturnCode.Success
end
30 changes: 30 additions & 0 deletions test/ssa_callback_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -226,3 +226,33 @@ let
@test 3.0 ∈ sol5.t
@test 6.0 ∈ sol5.t
end

# test that reset_aggregated_jumps! with update_jump_params kwarg dispatches correctly
# for SSAIntegrator (https://github.com/SciML/JumpProcesses.jl/issues/562)
let
rate1(u, p, t) = p[1] * u[1] * u[2]
affect1!(integrator) = (integrator.u[1] -= 1; integrator.u[2] += 1)
jump1 = ConstantRateJump(rate1, affect1!)

rate2(u, p, t) = p[2] * u[2]
affect2!(integrator) = (integrator.u[2] -= 1; integrator.u[3] += 1)
jump2 = ConstantRateJump(rate2, affect2!)

p = (1.0, 1.0)
prob = DiscreteProblem([990, 0, 0], (0.0, 250.0), p)
jump_prob = JumpProblem(prob, Direct(), jump1, jump2; rng)

# with update_jump_params = true, should behave same as without kwarg
int = init(jump_prob, SSAStepper())
int[2] = 10
reset_aggregated_jumps!(int; update_jump_params = true)
step!(int, 1000.0, true)
@test int.u[3] > 0 # at least some recovered, confirming jumps fired

# with update_jump_params = false, should also reset jump aggregation
int2 = init(jump_prob, SSAStepper())
int2[2] = 10
reset_aggregated_jumps!(int2; update_jump_params = false)
step!(int2, 1000.0, true)
@test int2.u[3] > 0
end
Loading