Skip to content

Commit 293bc94

Browse files
1187 Make set_initial_flows() of IDE model usable with and without age resolution (#1188)
- set_initial_flows() is now defined for both ConfirmedCasesDateEntry and ConfirmedCasesNoAgeEntry by using template argument. - Allow scale_confirmed_cases to bet set for each age group individually. - Added tests. Co-authored-by: lenaploetzke <[email protected]>
1 parent a06a573 commit 293bc94

8 files changed

Lines changed: 615 additions & 374 deletions

File tree

cpp/examples/ide_initialization.cpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "ide_secir/simulation.h"
2424
#include "ide_secir/parameters_io.h"
2525
#include "memilio/config.h"
26+
#include "memilio/io/epi_data.h"
2627
#include "memilio/utils/time_series.h"
2728
#include "memilio/utils/date.h"
2829
#include "memilio/math/eigen.h"
@@ -88,7 +89,21 @@ int main(int argc, char** argv)
8889
}
8990
else {
9091
// Use the real data for initialization.
91-
auto status = mio::isecir::set_initial_flows(model, dt, filename, mio::Date(2020, 12, 24));
92+
// Here we assume that the file contains data without age resolution, hence we use read_confirmed_cases_noage()
93+
// for reading the data and mio::ConfirmedCasesNoAgeEntry as EntryType in set_initial_flows().
94+
95+
auto status_read_data = mio::read_confirmed_cases_noage(filename);
96+
if (!status_read_data) {
97+
std::cout << "Error: " << status_read_data.error().formatted_message();
98+
return -1;
99+
}
100+
101+
std::vector<mio::ConfirmedCasesNoAgeEntry> rki_data = status_read_data.value();
102+
mio::CustomIndexArray<ScalarType, mio::AgeGroup> scale_confirmed_cases =
103+
mio::CustomIndexArray<ScalarType, mio::AgeGroup>(mio::AgeGroup(num_agegroups), 1.);
104+
105+
auto status = mio::isecir::set_initial_flows<mio::ConfirmedCasesNoAgeEntry>(
106+
model, dt, rki_data, mio::Date(2020, 12, 24), scale_confirmed_cases);
92107
if (!status) {
93108
std::cout << "Error: " << status.error().formatted_message();
94109
return -1;

cpp/models/ide_secir/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ add_library(ide_secir
66
simulation.cpp
77
parameters.h
88
parameters_io.h
9-
parameters_io.cpp
109
)
1110
target_link_libraries(ide_secir PUBLIC memilio)
1211
target_include_directories(ide_secir PUBLIC

cpp/models/ide_secir/model.cpp

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,12 @@ Model::Model(TimeSeries<ScalarType>&& transitions_init, CustomIndexArray<ScalarT
4646
, m_num_agegroups{num_agegroups}
4747

4848
{
49+
// Assert that input arguments for the total population have the correct size regarding
50+
// age groups.
51+
assert((size_t)m_N.size() == m_num_agegroups);
52+
4953
if (transitions.get_num_time_points() > 0) {
50-
// Add first time point in populations according to last time point in transitions which is where we start
54+
// Add first time point in m_populations according to last time point in m_transitions which is where we start
5155
// the simulation.
5256
populations.add_time_point<Eigen::VectorX<ScalarType>>(
5357
transitions.get_last_time(),
@@ -71,7 +75,7 @@ bool Model::check_constraints(ScalarType dt) const
7175
{
7276

7377
if (!((size_t)transitions.get_num_elements() == (size_t)InfectionTransition::Count * m_num_agegroups)) {
74-
log_error("A variable given for model construction is not valid. Number of elements in vector of"
78+
log_error("A variable given for model construction is not valid. Number of elements in vector of "
7579
"transitions does not match the required number.");
7680
return true;
7781
}
@@ -120,6 +124,21 @@ bool Model::check_constraints(ScalarType dt) const
120124
return true;
121125
}
122126

127+
if ((size_t)total_confirmed_cases.size() > 0 && (size_t)total_confirmed_cases.size() != m_num_agegroups) {
128+
log_error("Initialization failed. Number of elements in total_confirmed_cases does not match the number "
129+
"of age groups.");
130+
return true;
131+
}
132+
133+
if ((size_t)total_confirmed_cases.size() > 0) {
134+
for (AgeGroup group = AgeGroup(0); group < AgeGroup(m_num_agegroups); ++group) {
135+
if (total_confirmed_cases[group] < 0) {
136+
log_error("Initialization failed. One or more value of total_confirmed_cases is less than zero.");
137+
return true;
138+
}
139+
}
140+
}
141+
123142
return parameters.check_constraints();
124143
}
125144

cpp/models/ide_secir/model.h

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "memilio/config.h"
2727
#include "memilio/epidemiology/age_group.h"
2828
#include "memilio/utils/custom_index_array.h"
29+
#include "memilio/utils/date.h"
2930
#include "memilio/utils/time_series.h"
3031

3132
#include "vector"
@@ -34,6 +35,13 @@ namespace mio
3435
{
3536
namespace isecir
3637
{
38+
// Forward declaration of friend classes/functions of Model.
39+
class Model;
40+
class Simulation;
41+
template <typename EntryType>
42+
IOResult<void> set_initial_flows(Model& model, const ScalarType dt, const std::vector<EntryType> rki_data,
43+
const Date date, const CustomIndexArray<ScalarType, AgeGroup> scale_confirmed_cases);
44+
3745
class Model
3846
{
3947
using ParameterSet = Parameters;
@@ -130,6 +138,16 @@ class Model
130138
return m_initialization_method;
131139
}
132140

141+
/**
142+
* @brief Getter for number of age groups.
143+
*
144+
* @return Returns number of age groups.
145+
*/
146+
size_t get_num_agegroups() const
147+
{
148+
return m_num_agegroups;
149+
}
150+
133151
/**
134152
* @brief Setter for the tolerance used to calculate the maximum support of the TransitionDistributions.
135153
*
@@ -358,8 +376,10 @@ class Model
358376
friend class Simulation;
359377
// In set_initial_flows(), we compute initial flows based on RKI data using the (private) compute_flow() function
360378
// which is why it is defined as a friend function.
361-
friend IOResult<void> set_initial_flows(Model& model, ScalarType dt, std::string const& path, Date date,
362-
ScalarType scale_confirmed_cases);
379+
template <typename EntryType>
380+
friend IOResult<void> set_initial_flows(Model& model, const ScalarType dt, const std::vector<EntryType> rki_data,
381+
const Date date,
382+
const CustomIndexArray<ScalarType, AgeGroup> scale_confirmed_cases);
363383
};
364384

365385
} // namespace isecir

0 commit comments

Comments
 (0)