Skip to content

Commit 60a99ba

Browse files
1133 Make simulation functions private in Model class of IDE model (#1166)
- Make member functions that are required for simulating private. - Define Simulation class and set_initial_flows() as friends. Co-authored-by: lenaploetzke <[email protected]>
1 parent 754675b commit 60a99ba

2 files changed

Lines changed: 202 additions & 192 deletions

File tree

cpp/models/ide_secir/model.cpp

Lines changed: 115 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,99 @@ Model::Model(TimeSeries<ScalarType>&& init, CustomIndexArray<ScalarType, AgeGrou
6767
}
6868
}
6969

70+
bool Model::check_constraints(ScalarType dt) const
71+
{
72+
73+
if (!((size_t)m_transitions.get_num_elements() == (size_t)InfectionTransition::Count * m_num_agegroups)) {
74+
log_error("A variable given for model construction is not valid. Number of elements in transition vector "
75+
"does not match the required number.");
76+
return true;
77+
}
78+
79+
for (AgeGroup group = AgeGroup(0); group < AgeGroup(m_num_agegroups); ++group) {
80+
81+
for (int i = 0; i < (int)InfectionState::Count; i++) {
82+
int index = get_state_flat_index(i, group);
83+
if (m_populations[0][index] < 0) {
84+
log_error("Initialization failed. Initial values for populations are less than zero.");
85+
return true;
86+
}
87+
}
88+
}
89+
90+
// It may be possible to run the simulation with fewer time points, but this number ensures that it is possible.
91+
if (m_transitions.get_num_time_points() < (Eigen::Index)std::ceil(get_global_support_max(dt) / dt)) {
92+
log_error("Initialization failed. Not enough time points for transitions given before start of "
93+
"simulation.");
94+
return true;
95+
}
96+
97+
for (AgeGroup group = AgeGroup(0); group < AgeGroup(m_num_agegroups); ++group) {
98+
99+
for (int i = 0; i < m_transitions.get_num_time_points(); i++) {
100+
for (int j = 0; j < (int)InfectionTransition::Count; j++) {
101+
int index = get_transition_flat_index(j, group);
102+
if (m_transitions[i][index] < 0) {
103+
log_error("Initialization failed. One or more initial value for transitions is less than zero.");
104+
return true;
105+
}
106+
}
107+
}
108+
}
109+
if (m_transitions.get_last_time() != m_populations.get_last_time()) {
110+
log_error("Last time point of TimeSeries for transitions does not match last time point of "
111+
"TimeSeries for "
112+
"compartments. Both of these time points have to agree for a sensible simulation.");
113+
return true;
114+
}
115+
116+
if (m_populations.get_num_time_points() != 1) {
117+
log_error("The TimeSeries for the compartments contains more than one time point. It is unclear how to "
118+
"initialize.");
119+
return true;
120+
}
121+
122+
return parameters.check_constraints();
123+
}
124+
125+
// Note that this function computes the global_support_max via the get_support_max() method and does not make use
126+
// of the vector m_transitiondistributions_support_max. This is because the global_support_max is already used in
127+
// check_constraints and we cannot ensure that the vector has already been computed when checking for constraints
128+
// (which usually happens before setting the initial flows and simulating).
129+
ScalarType Model::get_global_support_max(ScalarType dt) const
130+
{
131+
ScalarType global_support_max = 0.;
132+
ScalarType global_support_max_new = 0.;
133+
for (AgeGroup group = AgeGroup(0); group < AgeGroup(m_num_agegroups); ++group) {
134+
global_support_max_new = std::max(
135+
{parameters.get<TransitionDistributions>()[group][(int)InfectionTransition::ExposedToInfectedNoSymptoms]
136+
.get_support_max(dt, m_tol),
137+
parameters
138+
.get<TransitionDistributions>()[group][(int)InfectionTransition::InfectedNoSymptomsToInfectedSymptoms]
139+
.get_support_max(dt, m_tol),
140+
parameters.get<TransitionDistributions>()[group][(int)InfectionTransition::InfectedNoSymptomsToRecovered]
141+
.get_support_max(dt, m_tol),
142+
parameters
143+
.get<TransitionDistributions>()[group][(int)InfectionTransition::InfectedSymptomsToInfectedSevere]
144+
.get_support_max(dt, m_tol),
145+
parameters.get<TransitionDistributions>()[group][(int)InfectionTransition::InfectedSymptomsToRecovered]
146+
.get_support_max(dt, m_tol),
147+
parameters
148+
.get<TransitionDistributions>()[group][(int)InfectionTransition::InfectedSevereToInfectedCritical]
149+
.get_support_max(dt, m_tol),
150+
parameters.get<TransitionDistributions>()[group][(int)InfectionTransition::InfectedSevereToRecovered]
151+
.get_support_max(dt, m_tol),
152+
parameters.get<TransitionDistributions>()[group][(int)InfectionTransition::InfectedCriticalToDead]
153+
.get_support_max(dt, m_tol),
154+
parameters.get<TransitionDistributions>()[group][(int)InfectionTransition::InfectedCriticalToRecovered]
155+
.get_support_max(dt, m_tol)});
156+
if (global_support_max_new > global_support_max) {
157+
global_support_max = global_support_max_new;
158+
}
159+
}
160+
return global_support_max;
161+
}
162+
70163
// ---- Functionality to calculate the sizes of the compartments for time t0. ----
71164
void Model::compute_compartment_from_flows(ScalarType dt, Eigen::Index idx_InfectionState, AgeGroup group,
72165
Eigen::Index idx_IncomingFlow, int idx_TransitionDistribution1,
@@ -384,7 +477,26 @@ void Model::flows_current_timestep(ScalarType dt)
384477
compute_flow(Eigen::Index(InfectionTransition::InfectedCriticalToRecovered),
385478
Eigen::Index(InfectionTransition::InfectedSevereToInfectedCritical), dt, group);
386479
}
387-
} // namespace isecir
480+
}
481+
482+
void Model::update_compartment_from_flow(InfectionState infectionState,
483+
std::vector<InfectionTransition> const& IncomingFlows,
484+
std::vector<InfectionTransition> const& OutgoingFlows, AgeGroup group)
485+
{
486+
int state_idx = get_state_flat_index(Eigen::Index(infectionState), group);
487+
488+
Eigen::Index num_time_points = m_populations.get_num_time_points();
489+
ScalarType updated_compartment = m_populations[num_time_points - 2][state_idx];
490+
for (const InfectionTransition& inflow : IncomingFlows) {
491+
int inflow_idx = get_transition_flat_index(Eigen::Index(inflow), group);
492+
updated_compartment += m_transitions.get_last_value()[inflow_idx];
493+
}
494+
for (const InfectionTransition& outflow : OutgoingFlows) {
495+
int outflow_idx = get_transition_flat_index(Eigen::Index(outflow), group);
496+
updated_compartment -= m_transitions.get_last_value()[outflow_idx];
497+
}
498+
m_populations.get_last_value()[state_idx] = updated_compartment;
499+
}
388500

389501
void Model::update_compartments()
390502
{
@@ -434,25 +546,6 @@ void Model::update_compartments()
434546
}
435547
}
436548

437-
void Model::update_compartment_from_flow(InfectionState infectionState,
438-
std::vector<InfectionTransition> const& IncomingFlows,
439-
std::vector<InfectionTransition> const& OutgoingFlows, AgeGroup group)
440-
{
441-
int state_idx = get_state_flat_index(Eigen::Index(infectionState), group);
442-
443-
Eigen::Index num_time_points = m_populations.get_num_time_points();
444-
ScalarType updated_compartment = m_populations[num_time_points - 2][state_idx];
445-
for (const InfectionTransition& inflow : IncomingFlows) {
446-
int inflow_idx = get_transition_flat_index(Eigen::Index(inflow), group);
447-
updated_compartment += m_transitions.get_last_value()[inflow_idx];
448-
}
449-
for (const InfectionTransition& outflow : OutgoingFlows) {
450-
int outflow_idx = get_transition_flat_index(Eigen::Index(outflow), group);
451-
updated_compartment -= m_transitions.get_last_value()[outflow_idx];
452-
}
453-
m_populations.get_last_value()[state_idx] = updated_compartment;
454-
}
455-
456549
void Model::compute_forceofinfection(ScalarType dt, bool initialization)
457550
{
458551

@@ -532,8 +625,9 @@ void Model::compute_forceofinfection(ScalarType dt, bool initialization)
532625
m_forceofinfection[i] += divNj * sum;
533626
}
534627
}
535-
} // namespace mio
628+
}
536629

630+
// ---- Functionality to set vectors with necessary information regarding TransitionDistributions. ----
537631
void Model::set_transitiondistributions_support_max(ScalarType dt)
538632
{
539633
m_transitiondistributions_support_max = CustomIndexArray<std::vector<ScalarType>, AgeGroup>(
@@ -618,43 +712,5 @@ void Model::set_transitiondistributions_in_forceofinfection(ScalarType dt)
618712
}
619713
}
620714

621-
// Note that this function computes the global_support_max via the get_support_max() method and does not make use
622-
// of the vector m_transitiondistributions_support_max. This is because the global_support_max is already used in
623-
// check_constraints and we cannot ensure that the vector has already been computed when checking for constraints
624-
// (which usually happens before setting the initial flows and simulating).
625-
ScalarType Model::get_global_support_max(ScalarType dt) const
626-
{
627-
ScalarType global_support_max = 0.;
628-
ScalarType global_support_max_new = 0.;
629-
for (AgeGroup group = AgeGroup(0); group < AgeGroup(m_num_agegroups); ++group) {
630-
global_support_max_new = std::max(
631-
{parameters.get<TransitionDistributions>()[group][(int)InfectionTransition::ExposedToInfectedNoSymptoms]
632-
.get_support_max(dt, m_tol),
633-
parameters
634-
.get<TransitionDistributions>()[group][(int)InfectionTransition::InfectedNoSymptomsToInfectedSymptoms]
635-
.get_support_max(dt, m_tol),
636-
parameters.get<TransitionDistributions>()[group][(int)InfectionTransition::InfectedNoSymptomsToRecovered]
637-
.get_support_max(dt, m_tol),
638-
parameters
639-
.get<TransitionDistributions>()[group][(int)InfectionTransition::InfectedSymptomsToInfectedSevere]
640-
.get_support_max(dt, m_tol),
641-
parameters.get<TransitionDistributions>()[group][(int)InfectionTransition::InfectedSymptomsToRecovered]
642-
.get_support_max(dt, m_tol),
643-
parameters
644-
.get<TransitionDistributions>()[group][(int)InfectionTransition::InfectedSevereToInfectedCritical]
645-
.get_support_max(dt, m_tol),
646-
parameters.get<TransitionDistributions>()[group][(int)InfectionTransition::InfectedSevereToRecovered]
647-
.get_support_max(dt, m_tol),
648-
parameters.get<TransitionDistributions>()[group][(int)InfectionTransition::InfectedCriticalToDead]
649-
.get_support_max(dt, m_tol),
650-
parameters.get<TransitionDistributions>()[group][(int)InfectionTransition::InfectedCriticalToRecovered]
651-
.get_support_max(dt, m_tol)});
652-
if (global_support_max_new > global_support_max) {
653-
global_support_max = global_support_max_new;
654-
}
655-
}
656-
return global_support_max;
657-
}
658-
659715
} // namespace isecir
660716
} // namespace mio

0 commit comments

Comments
 (0)