Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions src/variable_rate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,7 @@ end
# initialize: (cb, u, t, integrator)
function (c::VR_FRMEventCallback)(cb, u, t, integrator)
integrator.u.jump_u[c.idx] = -randexp(c.rng, typeof(integrator.t))
integrator.uprev.jump_u[c.idx] = integrator.u.jump_u[c.idx]
u_modified!(integrator, false)
u_modified!(integrator, true)
nothing
end

Expand Down
10 changes: 4 additions & 6 deletions test/ensemble_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ function make_sde_vr_jump_prob(agg; rng = StableRNG(12345))
end

# Helpers
first_jump_time(traj) = traj.t[2]
# First time strictly after t[1], robust to initialization saves at t=0.
first_jump_time(traj) = traj.t[findfirst(>(traj.t[1]), traj.t)]

# ==========================================================================
# 1. Serial ensemble: sequential trajectories get different RNG streams
Expand Down Expand Up @@ -186,16 +187,13 @@ end
#
# For VR_FRM, each trajectory's first jump time is determined by the initial
# jump_u threshold (set to -randexp() by the VR_FRMEventCallback initialize).
# Distinct thresholds → distinct first event times, so we verify by checking
# that the second time point (first event) differs across serial trajectories.
# Distinct thresholds → distinct first event times.
# ==========================================================================

@testset "VR_FRM: jump_u thresholds unique per trajectory (EnsembleSerial)" begin
jprob = make_vr_jump_prob(VR_FRM())
sol = solve(EnsembleProblem(jprob), Tsit5(), EnsembleSerial();
trajectories = 3)
# The second time point is when the first variable-rate jump fires,
# directly reflecting the initial -randexp() threshold.
event_times = [sol.u[i].t[2] for i in 1:3]
event_times = [first_jump_time(sol.u[i]) for i in 1:3]
@test allunique(event_times)
end
3 changes: 2 additions & 1 deletion test/monte_carlo_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ jump_prob = JumpProblem(prob, Direct(), jump; vr_aggregator = VR_FRM(), rng)
monte_prob = EnsembleProblem(jump_prob)
sol = solve(monte_prob, SRIW1(), EnsembleSerial(), trajectories = 3,
save_everystep = false, dt = 0.001, adaptive = false)
@test sol.u[1].t[2] != sol.u[2].t[2] != sol.u[3].t[2]
first_event(traj) = traj.t[findfirst(>(traj.t[1]), traj.t)]
@test first_event(sol.u[1]) != first_event(sol.u[2]) != first_event(sol.u[3])

jump_prob = JumpProblem(prob, Direct(), jump; vr_aggregator = VR_Direct(), rng)
monte_prob = EnsembleProblem(jump_prob)
Expand Down
12 changes: 8 additions & 4 deletions test/remake_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,10 @@ let
@test jprob2.prob.u0.u === u0
sol = solve(jprob2, Tsit5())
u = sol[1, :]
@test length(u) > 2
@test all(>(u0[1]), u[3:end])
t = sol.t
first_nontstart = findfirst(>(t[1]), t)
@test !isnothing(first_nontstart)
@test all(>=(u0[1]), u[first_nontstart:end])
u0 = deepcopy(jprob2.prob.u0)
u0.u .= 0
jprob3 = remake(jprob2; u0)
Expand Down Expand Up @@ -126,6 +128,8 @@ let
@test jprob3.prob.u0 === u0eja
sol = solve(jprob3, Tsit5())
u = sol[1, :]
@test length(u) > 2
@test all(>(u0[1]), u[3:end])
t = sol.t
first_nontstart = findfirst(>(t[1]), t)
@test !isnothing(first_nontstart)
@test all(>=(u0[1]), u[first_nontstart:end])
end
4 changes: 2 additions & 2 deletions test/thread_safety.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using DiffEqBase
using DiffEqBase, Test
using JumpProcesses, OrdinaryDiffEq
using StableRNGs
rng = StableRNG(12345)
Expand Down Expand Up @@ -30,7 +30,7 @@ let
prob = EnsembleProblem(jump_prob, prob_func = prob_func)
sol = solve(prob, Tsit5(), EnsembleThreads(), trajectories = 400,
save_everystep = false)
firstrx_time = [sol.u[i].t[2] for i in 1:length(sol)]
firstrx_time = [sol.u[i].t[findfirst(>(sol.u[i].t[1]), sol.u[i].t)] for i in 1:length(sol)]
@test allunique(firstrx_time)
end
end
Loading