Skip to content

Commit f7b8d38

Browse files
HenrZuMaxBetzDLR
andauthored
859 Use FlowSimulation in Python Bindings (#860)
Co-authored-by: MaxBetzDLR <[email protected]>
1 parent 3d59234 commit f7b8d38

10 files changed

Lines changed: 192 additions & 5 deletions

File tree

cpp/memilio/compartments/flow_model.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,8 @@ using filtered_tuple_t = decltype(filter_tuple<OmittedTag>(std::declval<Tuple>()
5555

5656
// Remove all occurrences of OmittedTag from the types in an Index = IndexTemplate<types...>.
5757
template <class OmittedTag, template <class...> class IndexTemplate, class Index>
58-
using filtered_index_t = decltype(as_index<IndexTemplate>(
59-
std::declval<filtered_tuple_t<OmittedTag, decltype(as_tuple(std::declval<Index>()))>>()));
58+
using filtered_index_t = decltype(
59+
as_index<IndexTemplate>(std::declval<filtered_tuple_t<OmittedTag, decltype(as_tuple(std::declval<Index>()))>>()));
6060

6161
} //namespace details
6262

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
#############################################################################
2+
# Copyright (C) 2020-2024 MEmilio
3+
#
4+
# Authors: Henrik Zunker
5+
#
6+
# Contact: Martin J. Kuehn <[email protected]>
7+
#
8+
# Licensed under the Apache License, Version 2.0 (the "License");
9+
# you may not use this file except in compliance with the License.
10+
# You may obtain a copy of the License at
11+
#
12+
# http://www.apache.org/licenses/LICENSE-2.0
13+
#
14+
# Unless required by applicable law or agreed to in writing, software
15+
# distributed under the License is distributed on an "AS IS" BASIS,
16+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17+
# See the License for the specific language governing permissions and
18+
# limitations under the License.
19+
#############################################################################
20+
import argparse
21+
22+
import numpy as np
23+
24+
from memilio.simulation import Damping
25+
from memilio.simulation.oseir import Index_InfectionState
26+
from memilio.simulation.oseir import InfectionState as State
27+
from memilio.simulation.oseir import (Model, interpolate_simulation_result,
28+
simulate_flows)
29+
30+
31+
def run_oseir_simulation():
32+
"""
33+
Runs the c++ ode seir model using a flow simulation
34+
"""
35+
36+
# Define population of age groups
37+
populations = [83000]
38+
39+
days = 100 # number of days to simulate
40+
dt = 0.1
41+
42+
# Initialize Parameters
43+
model = Model()
44+
45+
# Compartment transition duration
46+
model.parameters.TimeExposed.value = 5.2
47+
model.parameters.TimeInfected.value = 6.
48+
49+
# Compartment transition propabilities
50+
model.parameters.TransmissionProbabilityOnContact.value = 1.
51+
52+
# Initial number of people in each compartment
53+
model.populations[Index_InfectionState(State.Exposed)] = 100
54+
model.populations[Index_InfectionState(State.Infected)] = 50
55+
model.populations[Index_InfectionState(State.Recovered)] = 10
56+
model.populations.set_difference_from_total(
57+
(Index_InfectionState(State.Susceptible)), populations[0])
58+
59+
# model.parameters.ContactPatterns = ContactMatrix(np.r_[0.5])
60+
model.parameters.ContactPatterns.baseline = np.ones((1, 1))
61+
model.parameters.ContactPatterns.minimum = np.zeros((1, 1))
62+
model.parameters.ContactPatterns.add_damping(
63+
Damping(coeffs=np.r_[0.9], t=30.0, level=0, type=0))
64+
65+
# Check logical constraints to parameters
66+
model.check_constraints()
67+
68+
# Run flow simulation
69+
(result, flows) = simulate_flows(0, days, dt, model)
70+
71+
print(result.print_table(["S", "E", "I", "R"], 16, 5))
72+
print(flows.print_table(["S->E", "E->I", "I->R"], 16, 5))
73+
74+
75+
if __name__ == "__main__":
76+
arg_parser = argparse.ArgumentParser(
77+
'ode seir model with flow simulation',
78+
description='Simple example demonstrating the setup and flow simulation of the OSEIR model.')
79+
args = arg_parser.parse_args()
80+
run_oseir_simulation()

pycode/examples/simulation/oseir_simple.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def run_oseir_simulation():
7575

7676
if __name__ == "__main__":
7777
arg_parser = argparse.ArgumentParser(
78-
'secir_simple',
78+
'oseir_simple',
7979
description='Simple example demonstrating the setup and simulation of the OSEIR model.')
8080
args = arg_parser.parse_args()
8181
run_oseir_simulation()
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
/*
2+
* Copyright (C) 2020-2024 MEmilio
3+
*
4+
* Authors: Henrik Zunker
5+
*
6+
* Contact: Martin J. Kuehn <[email protected]>
7+
*
8+
* Licensed under the Apache License, Version 2.0 (the "License");
9+
* you may not use this file except in compliance with the License.
10+
* You may obtain a copy of the License at
11+
*
12+
* http://www.apache.org/licenses/LICENSE-2.0
13+
*
14+
* Unless required by applicable law or agreed to in writing, software
15+
* distributed under the License is distributed on an "AS IS" BASIS,
16+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17+
* See the License for the specific language governing permissions and
18+
* limitations under the License.
19+
*/
20+
#ifndef PYMIO_FLOW_SIMULATION_H
21+
#define PYMIO_FLOW_SIMULATION_H
22+
23+
#include "memilio/compartments/flow_simulation.h"
24+
25+
#include "pybind11/pybind11.h"
26+
27+
namespace pymio
28+
{
29+
30+
template <class Model>
31+
void bind_Flow_Simulation(pybind11::module_& m)
32+
{
33+
pybind11::class_<mio::FlowSimulation<Model>>(m, "FlowSimulation")
34+
.def(pybind11::init<const Model&, double, double>(), pybind11::arg("model"), pybind11::arg("t0") = 0,
35+
pybind11::arg("dt") = 0.1)
36+
.def_property_readonly("result",
37+
pybind11::overload_cast<>(&mio::FlowSimulation<Model>::get_result, pybind11::const_),
38+
pybind11::return_value_policy::reference_internal)
39+
.def_property_readonly("flows", &mio::FlowSimulation<Model>::get_flows,
40+
pybind11::return_value_policy::reference_internal)
41+
.def("advance", &mio::FlowSimulation<Model>::advance, pybind11::arg("tmax"));
42+
}
43+
44+
} // namespace pymio
45+
46+
#endif //PYMIO_FLOW_SIMULATION_H

pycode/memilio-simulation/memilio/simulation/oseir.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "utils/custom_index_array.h"
2424
#include "utils/parameter_set.h"
2525
#include "compartments/simulation.h"
26+
#include "compartments/flow_simulation.h"
2627
#include "compartments/compartmentalmodel.h"
2728
#include "epidemiology/populations.h"
2829
#include "ode_seir/model.h"
@@ -82,5 +83,13 @@ PYBIND11_MODULE(_simulation_oseir, m)
8283
},
8384
"Simulates a oseir from t0 to tmax.", py::arg("t0"), py::arg("tmax"), py::arg("dt"), py::arg("model"));
8485

86+
m.def(
87+
"simulate_flows",
88+
[](double t0, double tmax, double dt, const mio::oseir::Model& model) {
89+
return mio::simulate_flows(t0, tmax, dt, model);
90+
},
91+
"Simulates a oseir with flows from t0 to tmax.", py::arg("t0"), py::arg("tmax"), py::arg("dt"),
92+
py::arg("model"));
93+
8594
m.attr("__version__") = "dev";
8695
}

pycode/memilio-simulation/memilio/simulation/secir.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "memilio/config.h"
2222
#include "pybind_util.h"
2323
#include "compartments/simulation.h"
24+
#include "compartments/flow_simulation.h"
2425
#include "compartments/compartmentalmodel.h"
2526
#include "epidemiology/populations.h"
2627
#include "utils/custom_index_array.h"
@@ -221,6 +222,14 @@ PYBIND11_MODULE(_simulation_secir, m)
221222
},
222223
"Simulates a Secir Model1 from t0 to tmax.", py::arg("t0"), py::arg("tmax"), py::arg("dt"), py::arg("model"));
223224

225+
m.def(
226+
"simulate_flows",
227+
[](double t0, double tmax, double dt, const mio::osecir::Model& model) {
228+
return mio::simulate_flows(t0, tmax, dt, model);
229+
},
230+
"Simulates a Secir model with flows from t0 to tmax.", py::arg("t0"), py::arg("tmax"), py::arg("dt"),
231+
py::arg("model"));
232+
224233
pymio::bind_ModelNode<mio::osecir::Model>(m, "ModelNode");
225234
pymio::bind_SimulationNode<mio::osecir::Simulation<>>(m, "SimulationNode");
226235
pymio::bind_ModelGraph<mio::osecir::Model>(m, "ModelGraph");

pycode/memilio-simulation/memilio/simulation/utils/time_series.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
*/
2020
#include "utils/time_series.h"
2121
#include "memilio/utils/time_series.h"
22+
#include <pybind11/pybind11.h>
23+
#include <pybind11/stl.h>
2224

2325
#include "pybind11/eigen.h"
2426

@@ -60,6 +62,13 @@ void bind_time_series(py::module_& m, std::string const& name)
6062
}
6163
},
6264
py::is_operator(), py::arg("index"), py::arg("v"))
65+
.def("print_table",
66+
[](const mio::TimeSeries<double>& self, const std::vector<std::string>& column_labels, size_t width,
67+
size_t precision) {
68+
std::ostringstream oss;
69+
self.print_table(column_labels, width, precision, oss);
70+
return oss.str();
71+
})
6372
.def("add_time_point",
6473
[](mio::TimeSeries<double>& self) {
6574
return self.add_time_point();

pycode/memilio-simulation/memilio/simulation_test/test_oseir.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from memilio.simulation import Damping
2525
from memilio.simulation.oseir import Index_InfectionState
2626
from memilio.simulation.oseir import InfectionState as State
27-
from memilio.simulation.oseir import Model, simulate
27+
from memilio.simulation.oseir import Model, simulate, simulate_flows
2828

2929

3030
class Test_oseir_integration(unittest.TestCase):
@@ -57,6 +57,19 @@ def test_simulate_simple(self):
5757
self.assertAlmostEqual(result.get_time(1), 0.1)
5858
self.assertAlmostEqual(result.get_last_time(), 100.)
5959

60+
def test_flow_simulation_simple(self):
61+
flow_sim_results = simulate_flows(
62+
t0=0., tmax=100., dt=0.1, model=self.model)
63+
flows = flow_sim_results[1]
64+
self.assertEqual(flows.get_time(0), 0.)
65+
self.assertEqual(flows.get_last_time(), 100.)
66+
self.assertEqual(len(flows.get_last_value()), 3)
67+
68+
compartments = flow_sim_results[0]
69+
self.assertEqual(compartments.get_time(0), 0.)
70+
self.assertEqual(compartments.get_last_time(), 100.)
71+
self.assertEqual(len(compartments.get_last_value()), 4)
72+
6073
def test_check_constraints_parameters(self):
6174

6275
model = Model()

pycode/memilio-simulation/memilio/simulation_test/test_secir.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from memilio.simulation import AgeGroup, ContactMatrix, Damping, UncertainContactMatrix
2525
from memilio.simulation.secir import Index_InfectionState
2626
from memilio.simulation.secir import InfectionState as State
27-
from memilio.simulation.secir import Model, Simulation, simulate
27+
from memilio.simulation.secir import Model, Simulation, simulate, simulate_flows
2828

2929

3030
class Test_secir_integration(unittest.TestCase):
@@ -74,6 +74,19 @@ def test_simulate_simple(self):
7474
self.assertAlmostEqual(result.get_time(1), 0.1)
7575
self.assertAlmostEqual(result.get_last_time(), 100.)
7676

77+
def test_flow_simulation_simple(self):
78+
flow_sim_results = simulate_flows(
79+
t0=0., tmax=100., dt=0.1, model=self.model)
80+
flows = flow_sim_results[1]
81+
self.assertEqual(flows.get_time(0), 0.)
82+
self.assertEqual(flows.get_last_time(), 100.)
83+
self.assertEqual(len(flows.get_last_value()), 15)
84+
85+
compartments = flow_sim_results[0]
86+
self.assertEqual(compartments.get_time(0), 0.)
87+
self.assertEqual(compartments.get_last_time(), 100.)
88+
self.assertEqual(len(compartments.get_last_value()), 10)
89+
7790
def test_simulation_simple(self):
7891
sim = Simulation(self.model, t0=0., dt=0.1)
7992
sim.advance(tmax=100.)

pycode/memilio-simulation/memilio/simulation_test/test_time_series.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,14 @@ def test_ndarray(self):
5353
assert_array_equal(ts.get_last_value(), np.r_[1.1, 1.2])
5454
assert_array_equal(ts.get_last_time(), 1.0)
5555

56+
def test_print_table(self):
57+
ts = mio.TimeSeries(1)
58+
ts.add_time_point(2, np.r_[1])
59+
ts.add_time_point(3.5, np.r_[2])
60+
output = ts.print_table(["a", "b"], 2, 2)
61+
self.assertEqual(
62+
output, 'Time a \n2.00 1.00\n3.50 2.00\n')
63+
5664

5765
if __name__ == '__main__':
5866
unittest.main()

0 commit comments

Comments
 (0)