Skip to content

Commit 3bdd44e

Browse files
committed
write prototype for custom model
1 parent 7183651 commit 3bdd44e

16 files changed

Lines changed: 970 additions & 232 deletions

cpp/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,11 @@ add_library(epidemiology
1818
epidemiology/utils/tensor_helpers.cpp
1919
epidemiology/math/euler.cpp
2020
epidemiology/math/euler.h
21+
epidemiology/model/ScalarType.h
22+
epidemiology/model/compartmentalmodel.h
23+
epidemiology/model/populations.h
24+
epidemiology/model/parameterset.h
25+
epidemiology/model/simulation.h
2126
epidemiology/secir/damping.cpp
2227
epidemiology/secir/damping.h
2328
epidemiology/secir/secir.cpp
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
#ifndef SCALARTYPE_H
2+
#define SCALARTYPE_H
3+
4+
using ScalarType = double;
5+
6+
#endif // POPULATIONS_H
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
#include "compartmentalmodel.h"
2+
3+
CompartmentalModel::CompartmentalModel()
4+
{
5+
6+
}
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
#ifndef COMPARTMENTALMODEL_H
2+
#define COMPARTMENTALMODEL_H
3+
4+
#include "epidemiology/model/ScalarType.h"
5+
#include <vector>
6+
#include <functional>
7+
#include <Eigen/Core>
8+
9+
namespace
10+
{
11+
12+
//some metaprogramming to transform a tuple into a parameter pack and use it as
13+
//an argument in a function.
14+
//Taken from https://stackoverflow.com/questions/7858817/unpacking-a-tuple-to-call-a-matching-function-pointer/9288547#9288547
15+
16+
template <typename Function, typename Tuple, size_t... I>
17+
auto call(Function f, Tuple t, std::index_sequence<I...>)
18+
{
19+
return f(std::get<I>(t)...);
20+
}
21+
22+
template <typename Function, typename Tuple>
23+
auto call(Function f, Tuple t)
24+
{
25+
static constexpr auto size = std::tuple_size<Tuple>::value;
26+
return call(f, t, std::make_index_sequence<size>{});
27+
}
28+
29+
} // namespace
30+
31+
namespace epi
32+
{
33+
34+
template <class Populations, class ParameterSet>
35+
struct CompartmentalModel {
36+
public:
37+
using FlowFunction =
38+
std::function<ScalarType(ParameterSet const& p, Eigen::Ref<const Eigen::VectorXd> y, double t)>;
39+
using Flow = std::tuple<typename Populations::Index, typename Populations::Index, FlowFunction>;
40+
41+
CompartmentalModel()
42+
{
43+
}
44+
45+
void add_flow(typename Populations::Index from, typename Populations::Index to, FlowFunction f)
46+
{
47+
flows.push_back(Flow(from, to, f));
48+
}
49+
50+
/**
51+
* @brief get_right_hand_side returns the right-hand-side f of the ODE dydt = f(y, t)
52+
* @param y the current state of the model as a flat array
53+
* @param t the current time
54+
* @param dydt a reference to the calculated output
55+
*/
56+
void get_right_hand_side(Eigen::Ref<const Eigen::VectorXd> y, double t, Eigen::Ref<Eigen::VectorXd> dydt) const
57+
{
58+
for (size_t i = 0; i < y.size(); ++i) {
59+
dydt[i] = 0;
60+
}
61+
for (auto flow : flows) {
62+
dydt[call(Populations::get_flat_index, std::get<0>(flow))] += std::get<2>(flow)(parameters, y, t);
63+
dydt[call(Populations::get_flat_index, std::get<1>(flow))] -= std::get<2>(flow)(parameters, y, t);
64+
}
65+
}
66+
67+
Eigen::VectorXd get_initial_values() const
68+
{
69+
return populations.get_compartments();
70+
}
71+
72+
Populations populations{};
73+
ParameterSet parameters{};
74+
75+
private:
76+
std::vector<Flow> flows{};
77+
};
78+
79+
} // namespace epi
80+
81+
#endif // COMPARTMENTALMODEL_H
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
#ifndef PARAMETERSET_H
2+
#define PARAMETERSET_H
3+
4+
namespace epi
5+
{
6+
7+
//TODO: Implement this as a compile-time map.
8+
// I don't know how to implement compile time maps... :(
9+
template <class... Params>
10+
class ParameterSet
11+
{
12+
public:
13+
ParameterSet()
14+
{
15+
}
16+
17+
template <class Parameter>
18+
void set(ScalarType value)
19+
{
20+
//TODO!!!
21+
}
22+
23+
template <class Parameter>
24+
typename Parameter::Type const get() const
25+
{
26+
//TODO!!!
27+
}
28+
};
29+
30+
} // namespace epi
31+
32+
#endif // COMPARTMENTALMODEL_H

0 commit comments

Comments
 (0)