Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/JumpProcesses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ using Reexport: Reexport, @reexport

# Explicit imports from standard libraries
using LinearAlgebra: LinearAlgebra, mul!
using Random: Random, randexp, randexp!, seed!
using Random: Random, randexp, seed!

# Explicit imports from external packages
using DocStringExtensions: DocStringExtensions, FIELDS, TYPEDEF
Expand Down
26 changes: 11 additions & 15 deletions src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,30 +65,26 @@ function __jump_init(_jump_prob::DiffEqBase.AbstractJumpProblem{P}, alg;
end
end

# Derive an independent seed from the caller's seed. When a caller (e.g. StochasticDiffEq)
# passes the same seed used for its noise process, we must produce a distinct seed for the
# jump aggregator's RNG. We cannot assume the JumpProblem's stored RNG is any particular
# type, so we pass the seed through `hash` (to decorrelate from the input) and then through
# a Xoshiro draw (to ensure strong mixing regardless of the target RNG's seeding quality).
const _JUMP_SEED_SALT = 0x4a756d7050726f63 # "JumPProc" in ASCII
_derive_jump_seed(seed) = rand(Random.Xoshiro(hash(seed, _JUMP_SEED_SALT)), UInt64)
Comment on lines +73 to +74
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ChrisRackauckas I'm not sure what the best approach to take here is. The problem is we don't want to just use Xoshiro(seed) as that will sample from the same stream that StochasticDiffEq is using (currently, this will all go away when we get the rng updates here too, but I'd like something right now we can have here for a non-breaking release fix).

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My impression is the Xoshiro seeding should be quite good at mixing, so I'm hoping this is reasonable...


function resetted_jump_problem(_jump_prob, seed)
jump_prob = deepcopy(_jump_prob)
# Only reseed if an explicit seed is provided. This respects the user's RNG choice
# and enables reproducibility. For EnsembleProblems, use prob_func to set unique seeds
# for each trajectory if different results are needed.
if seed !== nothing && !isempty(jump_prob.jump_callback.discrete_callbacks)
rng = jump_prob.jump_callback.discrete_callbacks[1].condition.rng
Random.seed!(rng, seed)
end

if !isempty(jump_prob.variable_jumps) && jump_prob.prob.u0 isa ExtendedJumpArray
randexp!(_jump_prob.rng, jump_prob.prob.u0.jump_u)
jump_prob.prob.u0.jump_u .*= -1
Random.seed!(rng, _derive_jump_seed(seed))
end
jump_prob
end

function reset_jump_problem!(jump_prob, seed)
if seed !== nothing && !isempty(jump_prob.jump_callback.discrete_callbacks)
Random.seed!(jump_prob.jump_callback.discrete_callbacks[1].condition.rng, seed)
end

if !isempty(jump_prob.variable_jumps) && jump_prob.prob.u0 isa ExtendedJumpArray
randexp!(jump_prob.rng, jump_prob.prob.u0.jump_u)
jump_prob.prob.u0.jump_u .*= -1
Random.seed!(jump_prob.jump_callback.discrete_callbacks[1].condition.rng,
_derive_jump_seed(seed))
end
end
201 changes: 201 additions & 0 deletions test/ensemble_problems.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
using JumpProcesses, OrdinaryDiffEq, StochasticDiffEq, Test
using StableRNGs, Random

# ==========================================================================
# Problem constructors
# ==========================================================================

# Constant-rate birth-death for SSAStepper / ODE-coupled tests
function make_ssa_jump_prob(; rng = StableRNG(12345))
j1 = ConstantRateJump((u, p, t) -> 10.0, integrator -> (integrator.u[1] += 1))
j2 = ConstantRateJump((u, p, t) -> 0.5 * u[1], integrator -> (integrator.u[1] -= 1))
dprob = DiscreteProblem([10], (0.0, 20.0))
JumpProblem(dprob, Direct(), j1, j2; rng)
end

# ODE + variable-rate jump
function make_vr_jump_prob(agg; rng = StableRNG(12345))
f!(du, u, p, t) = (du[1] = -0.1 * u[1]; nothing)
oprob = ODEProblem(f!, [100.0], (0.0, 10.0))
vrj = VariableRateJump((u, p, t) -> 0.5 * u[1],
integrator -> (integrator.u[1] -= 1.0))
JumpProblem(oprob, Direct(), vrj; vr_aggregator = agg, rng)
end

# SDE + variable-rate jump
function make_sde_vr_jump_prob(agg; rng = StableRNG(12345))
f!(du, u, p, t) = (du[1] = -0.1 * u[1]; nothing)
g!(du, u, p, t) = (du[1] = 0.1 * u[1]; nothing)
sprob = SDEProblem(f!, g!, [100.0], (0.0, 10.0))
vrj = VariableRateJump((u, p, t) -> 0.5 * u[1],
integrator -> (integrator.u[1] -= 1.0))
JumpProblem(sprob, Direct(), vrj; vr_aggregator = agg, rng)
end

# Helpers
first_jump_time(traj) = traj.t[2]

# ==========================================================================
# 1. Serial ensemble: sequential trajectories get different RNG streams
# ==========================================================================

@testset "EnsembleSerial: distinct streams" begin
@testset "SSAStepper" begin
jprob = make_ssa_jump_prob()
sol = solve(EnsembleProblem(jprob), SSAStepper(), EnsembleSerial();
trajectories = 3)
times = [first_jump_time(sol.u[i]) for i in 1:3]
@test allunique(times)
end

@testset "ODE + VR ($agg)" for agg in (VR_FRM(), VR_Direct(), VR_DirectFW())
jprob = make_vr_jump_prob(agg)
sol = solve(EnsembleProblem(jprob), Tsit5(), EnsembleSerial();
trajectories = 3)
times = [first_jump_time(sol.u[i]) for i in 1:3]
@test allunique(times)
finals = [sol.u[i].u[end][1] for i in 1:3]
@test allunique(finals)
end

# EM() uses a fixed time grid so jump event times aren't directly visible
# in t[2]; we check final values instead.
@testset "SDE + VR (VR_FRM)" begin
jprob = make_sde_vr_jump_prob(VR_FRM())
sol = solve(EnsembleProblem(jprob), EM(), EnsembleSerial();
trajectories = 3, dt = 0.01, save_everystep = false)
finals = [sol.u[i].u[end][1] for i in 1:3]
@test allunique(finals)
end
end

# ==========================================================================
# 2. Sequential solves on same thread: RNG advances between solves
# ==========================================================================

@testset "Sequential solves: different RNG streams" begin
@testset "SSAStepper" begin
jprob = make_ssa_jump_prob()
times = [first_jump_time(solve(jprob, SSAStepper())) for _ in 1:3]
@test allunique(times)
end

@testset "ODE + VR ($agg)" for agg in (VR_FRM(), VR_Direct(), VR_DirectFW())
jprob = make_vr_jump_prob(agg)
sols = [solve(jprob, Tsit5()) for _ in 1:3]
times = [first_jump_time(s) for s in sols]
@test allunique(times)
finals = [s.u[end][1] for s in sols]
@test allunique(finals)
end
end

# ==========================================================================
# 3. Threaded ensemble: no data race on the shared JumpProblem
#
# The ODE/SSA path through __jump_init receives seed=nothing from
# SciMLBase, so deepcopy'd problems on non-main threads start with
# identical RNG states. We only assert completion here — uniqueness
# requires explicit seeding (tested in section 4 below).
#
# The SDE path goes through StochasticDiffEq's __init which generates
# per-trajectory seeds, so we can additionally verify uniqueness there.
# ==========================================================================

@testset "EnsembleThreads: no data race" begin
@testset "SSAStepper" begin
jprob = make_ssa_jump_prob()
sol = solve(EnsembleProblem(jprob), SSAStepper(), EnsembleThreads();
trajectories = 4)
@test length(sol) == 4
end

@testset "ODE + VR ($agg)" for agg in (VR_FRM(), VR_Direct(), VR_DirectFW())
jprob = make_vr_jump_prob(agg)
# This path previously had a data race: resetted_jump_problem called
# randexp!(_jump_prob.rng, ...) on the shared original problem.
sol = solve(EnsembleProblem(jprob), Tsit5(), EnsembleThreads();
trajectories = 4, save_everystep = false)
@test length(sol) == 4
end

@testset "SDE + VR (VR_FRM): unique trajectories" begin
jprob = make_sde_vr_jump_prob(VR_FRM())
# StochasticDiffEq generates per-trajectory seeds and passes them to
# resetted_jump_problem, so trajectories should be distinct.
sol = solve(EnsembleProblem(jprob), EM(), EnsembleThreads();
trajectories = 4, dt = 0.01, save_everystep = false)
@test length(sol) == 4
finals = [sol.u[i].u[end][1] for i in 1:4]
@test length(unique(finals)) > 1
end
end

# ==========================================================================
# 4. Seed-based stream independence: resetted_jump_problem and
# reset_jump_problem! produce distinct RNG streams for different seeds
#
# This tests the mechanism that EnsembleThreads relies on (when seeds are
# provided by the caller, e.g. StochasticDiffEq) to get independent streams
# on different threads.
# ==========================================================================

@testset "resetted_jump_problem: different seeds → different streams" begin
jprob = make_ssa_jump_prob()
seeds = UInt64[100, 200, 300]

# Each seed should produce a distinct aggregator RNG state
rngs = map(seeds) do s
jp = JumpProcesses.resetted_jump_problem(jprob, s)
jp.jump_callback.discrete_callbacks[1].condition.rng
end
draws = [rand(rng) for rng in rngs]
@test allunique(draws)

# Same seed should be deterministic
jp1 = JumpProcesses.resetted_jump_problem(jprob, UInt64(42))
jp2 = JumpProcesses.resetted_jump_problem(jprob, UInt64(42))
rng1 = jp1.jump_callback.discrete_callbacks[1].condition.rng
rng2 = jp2.jump_callback.discrete_callbacks[1].condition.rng
@test rand(rng1) == rand(rng2)
end

@testset "reset_jump_problem!: different seeds → different streams" begin
seeds = UInt64[100, 200, 300]
draws = map(seeds) do s
jp = make_ssa_jump_prob()
JumpProcesses.reset_jump_problem!(jp, s)
rand(jp.jump_callback.discrete_callbacks[1].condition.rng)
end
@test allunique(draws)
end

@testset "_derive_jump_seed: decorrelates from input seed" begin
seed = UInt64(12345)
derived = JumpProcesses._derive_jump_seed(seed)
# Derived seed should differ from input
@test derived != seed
# Should be deterministic
@test derived == JumpProcesses._derive_jump_seed(seed)
# Different inputs → different outputs
@test JumpProcesses._derive_jump_seed(UInt64(1)) != JumpProcesses._derive_jump_seed(UInt64(2))
end

# ==========================================================================
# 5. Variable-rate: jump_u thresholds are unique per trajectory
#
# For VR_FRM, each trajectory's first jump time is determined by the initial
# jump_u threshold (set to -randexp() by the VR_FRMEventCallback initialize).
# Distinct thresholds → distinct first event times, so we verify by checking
# that the second time point (first event) differs across serial trajectories.
# ==========================================================================

@testset "VR_FRM: jump_u thresholds unique per trajectory (EnsembleSerial)" begin
jprob = make_vr_jump_prob(VR_FRM())
sol = solve(EnsembleProblem(jprob), Tsit5(), EnsembleSerial();
trajectories = 3)
# The second time point is when the first variable-rate jump fires,
# directly reflecting the initial -randexp() threshold.
event_times = [sol.u[i].t[2] for i in 1:3]
@test allunique(event_times)
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ end
@time @safetestset "Save_positions test" begin include("save_positions.jl") end
@time @safetestset "Ensemble Uniqueness test" begin include("ensemble_uniqueness.jl") end
@time @safetestset "Thread Safety test" begin include("thread_safety.jl") end
@time @safetestset "Ensemble Problem Tests" begin include("ensemble_problems.jl") end
@time @safetestset "A + B <--> C" begin include("reversible_binding.jl") end
@time @safetestset "Remake tests" begin include("remake_test.jl") end
@time @safetestset "ExtendedJumpArray remake tests" begin include("extended_jump_array_remake.jl") end
Expand Down
14 changes: 8 additions & 6 deletions test/variable_rate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -257,13 +257,15 @@ let

ode_prob = ODEProblem(ode_fxn, u0, tspan, p)
sjm_prob = JumpProblem(ode_prob, b_jump, d_jump; vr_aggregator = VR_FRM(), rng)
@test allunique(sjm_prob.prob.u0.jump_u)
u0old = copy(sjm_prob.prob.u0.jump_u)
# After callback initialize, integrator.u.jump_u should have unique thresholds
# that differ between sequential solves (RNG advances each time).
jump_u_old = zeros(length(sjm_prob.prob.u0.jump_u))
for i in 1:Nsims
sol = solve(sjm_prob, Tsit5(); saveat = tspan[2])
@test allunique(sjm_prob.prob.u0.jump_u)
@test all(u0old != sjm_prob.prob.u0.jump_u)
u0old .= sjm_prob.prob.u0.jump_u
integrator = init(sjm_prob, Tsit5(); saveat = tspan[2])
@test allunique(integrator.u.jump_u)
@test integrator.u.jump_u != jump_u_old
jump_u_old .= integrator.u.jump_u
solve!(integrator)
end
end

Expand Down
Loading