Skip to content

Commit f42b89a

Browse files
610 implement initialization scheme for flows (#952)
- Added functionality to calculate initial flows based on real data for an IDE-SECIR model. - Added functionality to get mean values to StateAgeFunction. - Added tests and an example. Co-authored-by: Anna Wendler <[email protected]>
1 parent 8d0d615 commit f42b89a

14 files changed

Lines changed: 867 additions & 97 deletions

cpp/examples/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,3 +116,9 @@ if(MEMILIO_HAS_HDF5)
116116
target_link_libraries(ode_secir_save_results_example PRIVATE memilio ode_secir)
117117
target_compile_options(ode_secir_save_results_example PRIVATE ${MEMILIO_CXX_FLAGS_ENABLE_WARNING_ERRORS})
118118
endif()
119+
120+
if(MEMILIO_HAS_JSONCPP)
121+
add_executable(ide_initialization_example ide_initialization.cpp)
122+
target_link_libraries(ide_initialization_example PRIVATE memilio ide_secir)
123+
target_compile_options(ide_initialization_example PRIVATE ${MEMILIO_CXX_FLAGS_ENABLE_WARNING_ERRORS})
124+
endif()
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
/*
2+
* Copyright (C) 2020-2024 MEmilio
3+
*
4+
* Authors: Lena Ploetzke, Anna Wendler
5+
*
6+
* Contact: Martin J. Kuehn <[email protected]>
7+
*
8+
* Licensed under the Apache License, Version 2.0 (the "License");
9+
* you may not use this file except in compliance with the License.
10+
* You may obtain a copy of the License at
11+
*
12+
* http://www.apache.org/licenses/LICENSE-2.0
13+
*
14+
* Unless required by applicable law or agreed to in writing, software
15+
* distributed under the License is distributed on an "AS IS" BASIS,
16+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17+
* See the License for the specific language governing permissions and
18+
* limitations under the License.
19+
*/
20+
21+
#include "ide_secir/model.h"
22+
#include "ide_secir/infection_state.h"
23+
#include "ide_secir/simulation.h"
24+
#include "ide_secir/parameters_io.h"
25+
#include "memilio/config.h"
26+
#include "memilio/utils/time_series.h"
27+
#include "memilio/utils/date.h"
28+
#include "memilio/math/eigen.h"
29+
#include <string>
30+
#include <vector>
31+
#include <iostream>
32+
33+
/**
34+
* @brief Function to check the parameters provided in the command line.
35+
*/
36+
std::string setup(int argc, char** argv)
37+
{
38+
if (argc == 2) {
39+
std::cout << "Using file " << argv[1] << "." << std::endl;
40+
return (std::string)argv[1];
41+
}
42+
else {
43+
if (argc > 2) {
44+
mio::log_warning("Too many arguments given.");
45+
}
46+
else {
47+
mio::log_warning("No arguments given.");
48+
}
49+
return "";
50+
}
51+
}
52+
53+
int main(int argc, char** argv)
54+
{
55+
// This is a simple example to demonstrate how to set initial data for the IDE-SECIR model using real data.
56+
// A default initialization is used if no filename is provided in the command line.
57+
// Have a look at the documentation of the set_initial_flows() function in models/ide_secir/parameters_io.h for a
58+
// description of how to download suitable data.
59+
// A valid filename could be for example "../../data/pydata/Germany/cases_all_germany_ma7.json" if the functionality to download real data is used.
60+
// The default parameters of the IDE-SECIR model are used, so that the simulation results are not realistic and are for demonstration purpose only.
61+
62+
// Initialize model.
63+
ScalarType total_population = 80 * 1e6;
64+
ScalarType deaths = 0; // The number of deaths will be overwritten if real data is used for initialization.
65+
ScalarType dt = 0.5;
66+
mio::isecir::Model model(mio::TimeSeries<ScalarType>((int)mio::isecir::InfectionTransition::Count),
67+
total_population, deaths);
68+
69+
// Check provided parameters.
70+
std::string filename = setup(argc, argv);
71+
if (filename.empty()) {
72+
std::cout << "You did not provide a valid filename. A default initialization is used." << std::endl;
73+
74+
using Vec = mio::TimeSeries<ScalarType>::Vector;
75+
mio::TimeSeries<ScalarType> init((int)mio::isecir::InfectionTransition::Count);
76+
init.add_time_point<Eigen::VectorXd>(-7., Vec::Constant((int)mio::isecir::InfectionTransition::Count, 1. * dt));
77+
while (init.get_last_time() < -dt / 2) {
78+
init.add_time_point(init.get_last_time() + dt,
79+
Vec::Constant((int)mio::isecir::InfectionTransition::Count, 1. * dt));
80+
}
81+
model.m_transitions = init;
82+
}
83+
else {
84+
// Use the real data for initialization.
85+
auto status = mio::isecir::set_initial_flows(model, dt, filename, mio::Date(2020, 12, 24));
86+
if (!status) {
87+
std::cout << "Error: " << status.error().formatted_message();
88+
return -1;
89+
}
90+
}
91+
92+
// Carry out simulation.
93+
mio::isecir::Simulation sim(model, dt);
94+
sim.advance(2.);
95+
96+
// Print results.
97+
sim.get_transitions().print_table({"S->E", "E->C", "C->I", "C->R", "I->H", "I->R", "H->U", "H->R", "U->D", "U->R"},
98+
16, 8);
99+
100+
return 0;
101+
}

cpp/examples/ide_secir.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,11 +73,13 @@ int main()
7373
mio::SmootherCosine smoothcos(2.0);
7474
mio::StateAgeFunctionWrapper delaydistribution(smoothcos);
7575
std::vector<mio::StateAgeFunctionWrapper> vec_delaydistrib(num_transitions, delaydistribution);
76-
vec_delaydistrib[(int)mio::isecir::InfectionTransition::SusceptibleToExposed].set_parameter(3.0);
76+
// TransitionDistribution is not used for SusceptibleToExposed. Therefore, the parameter can be set to any value.
77+
vec_delaydistrib[(int)mio::isecir::InfectionTransition::SusceptibleToExposed].set_parameter(-1.);
7778
vec_delaydistrib[(int)mio::isecir::InfectionTransition::InfectedNoSymptomsToInfectedSymptoms].set_parameter(4.0);
7879
model.parameters.set<mio::isecir::TransitionDistributions>(vec_delaydistrib);
7980

80-
std::vector<ScalarType> vec_prob((int)mio::isecir::InfectionTransition::Count, 0.5);
81+
std::vector<ScalarType> vec_prob(num_transitions, 0.5);
82+
// The following probabilities must be 1, as there is no other way to go.
8183
vec_prob[Eigen::Index(mio::isecir::InfectionTransition::SusceptibleToExposed)] = 1;
8284
vec_prob[Eigen::Index(mio::isecir::InfectionTransition::ExposedToInfectedNoSymptoms)] = 1;
8385
model.parameters.set<mio::isecir::TransitionProbabilities>(vec_prob);

cpp/memilio/epidemiology/state_age_function.h

Lines changed: 101 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,13 @@ namespace mio
5858
* decreasing. This is no limitation as the support is only needed for StateAgeFunctions of Type a) as given above.
5959
* For classes of type b) a dummy implementation logging an error and returning -2 for get_support_max() should be implemented.
6060
*
61+
* The get_mean method is virtual and implements a basic version to determine the mean value of the StateAgeFunction.
62+
* The base class implementation uses the fact that the StateAgeFunction is a survival function
63+
* (i.e. 1-CDF for any cumulative distribution function CDF).
64+
* Therefore, the base class implementation should only be used for StateAgeFunction%s of type a).
65+
* For some derived classes there is a more efficient way (see e.g., ExponentialDecay) to do this which is
66+
* why it can be overridden.
67+
*
6168
* See ExponentialDecay, SmootherCosine and ConstantFunction for examples of derived classes.
6269
*/
6370
struct StateAgeFunction {
@@ -69,8 +76,10 @@ struct StateAgeFunction {
6976
*/
7077
StateAgeFunction(ScalarType init_parameter)
7178
: m_parameter{init_parameter}
72-
, m_support_max{-1.} // initialize support maximum as not set
73-
, m_support_tol{-1.} // initialize support tolerance as not set
79+
, m_mean{-1.} // Initialize mean as not set.
80+
, m_mean_tol{-1.} // Initialize tolerance for computation of mean as not set.
81+
, m_support_max{-1.} // Initialize support maximum as not set.
82+
, m_support_tol{-1.} // Initialize tolerance for computation of support as not set.
7483
{
7584
}
7685

@@ -144,6 +153,7 @@ struct StateAgeFunction {
144153
m_parameter = new_parameter;
145154

146155
m_support_max = -1.;
156+
m_mean = -1;
147157
}
148158

149159
/**
@@ -176,6 +186,37 @@ struct StateAgeFunction {
176186
return m_support_max;
177187
}
178188

189+
/**
190+
* @brief Computes the mean value of the function using the time step size dt and some tolerance tol.
191+
*
192+
* This is a basic version to determine the mean value of a survival function
193+
* through numerical integration of the integral that describes the expected value.
194+
* This basic implementation is only valid if the StateAgeFunction is of type a). Otherwise it should be overridden.
195+
*
196+
* For some specific derivations of StateAgeFunction%s there are more efficient ways to determine the
197+
* the mean value which is why this member function is virtual and can be overridden (see, e.g., ExponentialDecay).
198+
* The mean value is only needed for StateAgeFunction%s that are used as TransitionDistribution%s.
199+
*
200+
* @param[in] dt Time step size used for the numerical integration.
201+
* @param[in] tol The maximum support used for numerical integration is calculated using this tolerance.
202+
* @return ScalarType mean value.
203+
*/
204+
virtual ScalarType get_mean(ScalarType dt = 1., ScalarType tol = 1e-10)
205+
{
206+
if (!floating_point_equal(m_mean_tol, tol, 1e-14) || floating_point_equal(m_mean, -1., 1e-14)) {
207+
// Integration using Trapezoidal rule.
208+
ScalarType mean = 0.5 * dt * eval(0 * dt);
209+
ScalarType supp_max_idx = std::ceil(get_support_max(dt, tol) / dt);
210+
for (int i = 1; i < supp_max_idx; i++) {
211+
mean += dt * eval(i * dt);
212+
}
213+
214+
m_mean = mean;
215+
m_mean_tol = tol;
216+
}
217+
return m_mean;
218+
}
219+
179220
/**
180221
* @brief Get type of StateAgeFunction, i.e.which derived class is used.
181222
*
@@ -205,6 +246,8 @@ struct StateAgeFunction {
205246
virtual StateAgeFunction* clone_impl() const = 0;
206247

207248
ScalarType m_parameter; ///< Parameter for function in derived class.
249+
ScalarType m_mean; ///< Mean value of the function.
250+
ScalarType m_mean_tol; ///< Tolerance for computation of the mean.
208251
ScalarType m_support_max; ///< Maximum of the support of the function.
209252
ScalarType m_support_tol; ///< Tolerance for computation of the support.
210253
};
@@ -241,6 +284,22 @@ struct ExponentialDecay : public StateAgeFunction {
241284
return std::exp(-m_parameter * state_age);
242285
}
243286

287+
/**
288+
* @brief Computes the mean value of the function.
289+
*
290+
* For ExponentialDecay, the mean value is the reciprocal of the function parameter.
291+
*
292+
* @param[in] dt Time step size used for the numerical integration (unused for ExponentialDecay).
293+
* @param[in] tol The maximum support used for numerical integration is calculated using this tolerance (unused for ExponentialDecay).
294+
* @return ScalarType mean value.
295+
*/
296+
ScalarType get_mean(ScalarType dt = 1., ScalarType tol = 1e-10) override
297+
{
298+
unused(dt);
299+
unused(tol);
300+
return 1. / m_parameter;
301+
}
302+
244303
protected:
245304
/**
246305
* @brief Implements clone for ExponentialDecay.
@@ -298,6 +357,10 @@ struct SmootherCosine : public StateAgeFunction {
298357
return m_support_max;
299358
}
300359

360+
// TODO: There is also a closed form for the mean value of Smoothercosine: 0.5*m_parameter.
361+
// However, a StateAgeFunction that uses the default implementation is required for testing purposes.
362+
// Therefore, the closed form is only used for comparison in the tests.
363+
// If another StateAgeFunction is implemented that uses the default implementation, the function get_mean() should be overwritten here.
301364
protected:
302365
/**
303366
* @brief Clones unique pointer to a StateAgeFunction.
@@ -365,6 +428,22 @@ struct ConstantFunction : public StateAgeFunction {
365428
return m_support_max;
366429
}
367430

431+
/**
432+
* @brief Computes the mean value of the function.
433+
*
434+
* For ConstantFunction, the mean value is the function parameter.
435+
*
436+
* @param[in] dt Time step size used for the numerical integration (unused for ConstantFunction).
437+
* @param[in] tol The maximum support used for numerical integration is calculated using this tolerance (unused for ConstantFunction).
438+
* @return ScalarType mean value.
439+
*/
440+
ScalarType get_mean(ScalarType dt = 1., ScalarType tol = 1e-10) override
441+
{
442+
unused(dt);
443+
unused(tol);
444+
return m_parameter;
445+
}
446+
368447
protected:
369448
/**
370449
* @brief Clones unique pointer to a StateAgeFunction.
@@ -460,7 +539,7 @@ struct StateAgeFunctionWrapper {
460539
/**
461540
* @brief Get type of StateAgeFunction, i.e. which derived class is used.
462541
*
463-
* @param[out] string
542+
* @return string
464543
*/
465544
std::string get_state_age_function_type() const
466545
{
@@ -498,11 +577,30 @@ struct StateAgeFunctionWrapper {
498577
m_function->set_parameter(new_parameter);
499578
}
500579

580+
/**
581+
* @brief Get the m_support_max object of m_function.
582+
*
583+
* @param[in] dt Time step size at which function will be evaluated.
584+
* @param[in] tol Tolerance used for cutting the support if the function value falls below.
585+
* @return ScalarType m_support_max
586+
*/
501587
ScalarType get_support_max(ScalarType dt, ScalarType tol = 1e-10) const
502588
{
503589
return m_function->get_support_max(dt, tol);
504590
}
505591

592+
/**
593+
* @brief Get the m_mean object of m_function.
594+
*
595+
* @param[in] dt Time step size used for the numerical integration.
596+
* @param[in] tol The maximum support used for numerical integration is calculated using this tolerance.
597+
* @return ScalarType m_mean
598+
*/
599+
ScalarType get_mean(ScalarType dt = 1., ScalarType tol = 1e-10) const
600+
{
601+
return m_function->get_mean(dt, tol);
602+
}
603+
506604
private:
507605
std::unique_ptr<StateAgeFunction> m_function; ///< Stores StateAgeFunction that is used in Wrapper.
508606
};

cpp/models/ide_secir/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@ add_library(ide_secir
33
model.h
44
model.cpp
55
simulation.h
6-
simulation.cpp
6+
simulation.cpp
77
parameters.h
8+
parameters_io.h
9+
parameters_io.cpp
810
)
911
target_link_libraries(ide_secir PUBLIC memilio)
1012
target_include_directories(ide_secir PUBLIC

0 commit comments

Comments
 (0)