Skip to content

Commit 451e147

Browse files
xsaschakocharlie0614reneSchm
authored
635 implement example for abm output object and time since transmission variable (#690)
Co-authored-by: Carlotta Gerstein <[email protected]> Co-authored-by: reneSchm <[email protected]>
1 parent 6d01187 commit 451e147

25 files changed

Lines changed: 764 additions & 439 deletions

cpp/benchmarks/abm.cpp

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,8 @@ mio::abm::Simulation make_simulation(size_t num_persons, std::initializer_list<u
5555
if (mio::UniformDistribution<double>::get_instance()(prng, 0.0, 1.0) < pct_infected) {
5656
auto state = mio::abm::InfectionState(
5757
mio::UniformIntDistribution<int>::get_instance()(prng, 1, int(mio::abm::InfectionState::Count) - 1));
58-
auto infection =
59-
mio::abm::Infection(prng, mio::abm::VirusVariant::Wildtype, person.get_age(),
60-
world.parameters, mio::abm::TimePoint(0), state);
58+
auto infection = mio::abm::Infection(prng, mio::abm::VirusVariant::Wildtype, person.get_age(),
59+
world.parameters, mio::abm::TimePoint(0), state);
6160
person.add_new_infection(std::move(infection));
6261
}
6362

@@ -68,8 +67,7 @@ mio::abm::Simulation make_simulation(size_t num_persons, std::initializer_list<u
6867
}
6968

7069
//masks at locations
71-
for (auto& loc : world.get_locations())
72-
{
70+
for (auto& loc : world.get_locations()) {
7371
//some % of locations require masks
7472
//skip homes so persons always have a place to go, simulation might break otherwise
7573
auto pct_require_mask = 0.2;
@@ -128,16 +126,20 @@ void abm_benchmark(benchmark::State& state, size_t num_persons, std::initializer
128126
state.ResumeTiming();
129127

130128
//simulated time should be long enough to have full infection runs and migration to every location
131-
auto final_time = sim.get_time() + mio::abm::days(10);
129+
auto final_time = sim.get_time() + mio::abm::days(10);
132130
sim.advance(final_time);
133131

134132
//debug output can be enabled to check for unexpected results (e.g. infections dieing out)
135133
//normally should have no significant effect on runtime
136134
const bool monitor_infection_activity = false;
137135
if constexpr (monitor_infection_activity) {
138136
std::cout << "num_persons = " << num_persons << "\n";
139-
std::cout << sim.get_result()[0].transpose() << "\n";
140-
std::cout << sim.get_result().get_last_value().transpose() << "\n";
137+
for (auto inf_state = 0; inf_state < (int)mio::abm::InfectionState::Count; inf_state++) {
138+
std::cout << "inf_state = " << inf_state << ", sum = "
139+
<< sim.get_world().get_subpopulation_combined(sim.get_time(),
140+
mio::abm::InfectionState(inf_state))
141+
<< "\n";
142+
}
141143
}
142144
}
143145
}

cpp/examples/abm_minimal.cpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include <fstream>
2323
#include <string>
2424
#include <iostream>
25+
#include "abm/common_abm_loggers.h"
2526

2627
int main()
2728
{
@@ -149,19 +150,22 @@ int main()
149150
// Set start and end time for the simulation.
150151
auto t0 = mio::abm::TimePoint(0);
151152
auto tmax = t0 + mio::abm::days(10);
153+
auto sim = mio::abm::Simulation(t0, std::move(world));
152154

153-
// Create and run the simualtion for the scenario defined above.
154-
auto sim = mio::abm::Simulation(t0, std::move(world));
155-
sim.advance(tmax);
155+
// Create a history object to store the time series of the infection states.
156+
mio::History<mio::abm::TimeSeriesWriter, mio::abm::LogInfectionState> historyTimeSeries{
157+
Eigen::Index(mio::abm::InfectionState::Count)};
156158

157-
std::ofstream outfile("abm_minimal.txt");
159+
// Run the simulation until tmax with the history object.
160+
sim.advance(tmax, historyTimeSeries);
158161

159162
// The results are written into the file "abm_minimal.txt" as a table with 9 columns.
160163
// The first column is Time. The other columns correspond to the number of people with a certain infection state at this Time:
161164
// Time = Time in days, S = Susceptible, E = Exposed, I_NS = InfectedNoSymptoms, I_Sy = InfectedSymptoms, I_Sev = InfectedSevere,
162165
// I_Crit = InfectedCritical, R = Recovered, D = Dead
163-
sim.get_result().print_table({"S", "E", "I_NS", "I_Sy", "I_Sev", "I_Crit", "R", "D"}, 7, 4, outfile);
164-
166+
std::ofstream outfile("abm_minimal.txt");
167+
std::get<0>(historyTimeSeries.get_log())
168+
.print_table({"S", "E", "I_NS", "I_Sy", "I_Sev", "I_Crit", "R", "D"}, 7, 4, outfile);
165169
std::cout << "Results written to abm_minimal.txt" << std::endl;
166170

167171
return 0;

cpp/memilio/io/history.h

Lines changed: 96 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -17,75 +17,86 @@
1717
* See the License for the specific language governing permissions and
1818
* limitations under the License.
1919
*/
20-
#ifndef HISTORY_OBJ_H
21-
#define HISTORY_OBJ_H
20+
#ifndef MIO_IO_HISTORY_H
21+
#define MIO_IO_HISTORY_H
2222

23+
#include "memilio/utils/metaprogramming.h"
2324
#include <vector>
2425
#include <tuple>
2526
#include <iostream>
2627

2728
namespace mio
2829
{
2930

30-
namespace details
31-
{
32-
/*
33-
* @brief Helper function to get the index of a Type in a pack of Types at compile time.
34-
* @tparam T The Type that is searched for.
35-
* @tparam Types All Types in the pack of Types.
36-
* This function is used to get the index of a logger in a pack of loggers, e.g. index_templ_pack<Logger, Loggers...> gets the index of Logger in the pack Loggers.
37-
* Only for use in a Data Writer, not at runtime.
38-
*/
39-
template <typename T, typename U = void, typename... Types>
40-
constexpr size_t index_templ_pack()
41-
{
42-
return std::is_same<T, U>::value ? 0 : 1 + index_templ_pack<T, Types...>();
43-
}
44-
} // namespace details
45-
46-
struct LogOnce {
47-
};
48-
49-
struct LogAlways {
50-
};
51-
52-
/*
53-
* @brief This class writes data retrieved from loggers to memory. It can be used as the Writer template parameter for the History class.
54-
* @tparam Loggers The loggers that are used to log data.
55-
*/
31+
/**
32+
* @brief This class writes data retrieved from loggers to memory. It can be used as the Writer template parameter for the History class.
33+
* @tparam Loggers The loggers that are used to log data.
34+
*/
5635
template <class... Loggers>
5736
struct DataWriterToMemory {
5837
using Data = std::tuple<std::vector<typename Loggers::Type>...>;
38+
/**
39+
* @brief Adds a new record for a given log result t to data.
40+
* The parameter Logger is used to determine the type of the record t, as well as the data index at which
41+
* the record should be added to.
42+
* @param[in] t The result of Logger::log.
43+
* @param[in,out] data An instance of Data to add the record to.
44+
* @tparam Logger The type of the logger used to record t.
45+
*/
5946
template <class Logger>
60-
static void write_this(const typename Logger::Type& t, Data& data)
47+
static void add_record(const typename Logger::Type& t, Data& data)
6148
{
62-
std::get<details::index_templ_pack<Logger, Loggers...>()>(data).push_back(t);
49+
std::get<mio::index_of_type_v<Logger, Loggers...>>(data).push_back(t);
6350
}
6451
};
6552

66-
/*
67-
* @brief History class that handles writers and Loggers.
68-
* The class provides a log(T t) function to add the current record.
69-
* It provides a get_log() function to access the record.
70-
* It uses Loggers to retrieve data, and Writer to record it.
71-
* A Logger has a type "Type", a function "static Type log(const T&)" and is derived from either LogOnce or LogAlways.
72-
* LogOnce is only passed to Writer on the first call to History::log, LogAlways on all calls.
73-
* The Writer defines the type "Data" of the record, and defines with "static void write_this(const Logger::Type&, Data&)" how log values are added to it.
74-
* @tparam Writer The writer that is used to handle the data, e.g. store it into an array.
75-
* @tparam Loggers The loggers that are used to log data.
76-
*/
77-
78-
template <template <class...> class Writer, class... Loggers>
53+
/**
54+
* @brief History class that handles writers and loggers.
55+
* History provides a function "log" to add a new record and a function "get_log" to access all records.
56+
*
57+
* The History class uses Loggers to retrieve data from a given input, and a Writer to record this data.
58+
* A Logger is a struct with a type `Type` and functions `Type log(const T&)` and `bool should_log(const T&)`.
59+
* All Loggers must be unique types and default construcible/destructible. Their member "should_log" indicates whether
60+
* to log, while "Type" and "log" determine what is logged. The input for "should_log" and "log" is the same input
61+
* of type T that is given to "History::log". (Note: T does not have to be a template for a Logger implementation.)
62+
* The Writer defines the type `Data` to store all records (i.e. the return values of Logger::log), and the function
63+
* `template <class Logger> static void add_record(const Logger::Type&, Data&)` to add a new record. "add_record" is
64+
* used whenever "History::log" was called and "Logger::should_log" is true.
65+
*
66+
* @tparam Writer The writer that is used to handle the data, e.g. store it into an array.
67+
* @tparam Loggers The loggers that are used to log data.
68+
*/
69+
template <template <class... /*Loggers*/> class Writer, class... Loggers>
7970
class History
8071
{
72+
static_assert(!has_duplicates_v<Loggers...>, "The Loggers used for a History must be unique.");
73+
8174
public:
8275
using WriteWrapper = Writer<Loggers...>;
8376

77+
History() = default;
78+
79+
History(typename WriteWrapper::Data data)
80+
: m_data(data)
81+
{
82+
}
83+
84+
History(typename Loggers::Type... args)
85+
: m_data(std::tie(args...))
86+
{
87+
}
88+
89+
/**
90+
* @brief Logs new records according to the Writer and Loggers.
91+
*
92+
* Calls the log_impl function for every Logger for Input t to record data.
93+
* @tparam T The type of the record.
94+
* @param[in] t The input to record.
95+
*/
8496
template <class T>
8597
void log(const T& t)
8698
{
87-
log_impl<T, Loggers...>(t);
88-
logged = true;
99+
(log_impl(t, std::get<index_of_type_v<Loggers, Loggers...>>(m_loggers)), ...);
89100
}
90101

91102
/**
@@ -98,41 +109,59 @@ class History
98109
return m_data;
99110
}
100111

101-
template <class Logger>
102-
const std::vector<typename Logger::Type>& get_log()
103-
{
104-
return std::get<details::index_templ_pack<Logger, Loggers...>()>(m_data);
105-
}
106-
107112
private:
108113
typename WriteWrapper::Data m_data;
114+
std::tuple<Loggers...> m_loggers;
109115

110-
bool logged = false;
111-
112-
template <class T, class logger, class... loggers>
113-
std::enable_if_t<std::is_base_of<LogOnce, logger>::value> log_impl(const T& t)
116+
/**
117+
* @brief Checks if the given logger should log. If so, adds a record of the log to m_data.
118+
* @param[in] t The argument given to History::log. Passed to Logger::should_log and Logger::log.
119+
* @param[in] logger A Logger instance.
120+
* @tparam Logger A logger from the list Loggers.
121+
*/
122+
template <class T, class Logger>
123+
void log_impl(const T& t, Logger& logger)
114124
{
115-
if (!logged) {
116-
WriteWrapper::template write_this<logger>(logger::log(t), m_data);
125+
if (logger.should_log(t)) {
126+
WriteWrapper::template add_record<Logger>(logger.log(t), m_data);
117127
}
118-
log_impl<T, loggers...>(t);
119128
}
129+
};
120130

121-
template <class T, class logger, class... loggers>
122-
std::enable_if_t<std::is_base_of<LogAlways, logger>::value> log_impl(const T& t)
131+
template <class... Loggers>
132+
using HistoryWithMemoryWriter = History<DataWriterToMemory, Loggers...>;
133+
134+
/**
135+
* LogOnce and LogAlways can be used as a base class to write a logger for History. They each provide the function
136+
* `bool should_log(const T&)`, so that only the type `Type` and the function `Type log(const T&)` have to be
137+
* implemented for the derived logger, where T is some input type (the same type that is given to History::log).
138+
* LogOnce logs only on the first call to History::log, while LogAlways logs on every call.
139+
*
140+
* For any other logging behaviour, should_log has to be defined in the logger (no base class required).
141+
* @see History for a full list of requirements for a logger.
142+
* @{
143+
*/
144+
struct LogOnce {
145+
bool was_logged = false; ///< Remember if this Logger was logged already.
146+
147+
/// @brief For any type T, returns true on the first call only, and false thereafter.
148+
template <class T>
149+
bool should_log(const T&)
123150
{
124-
log_impl<T, loggers...>(t);
125-
WriteWrapper::template write_this<logger>(logger::log(t), m_data);
151+
return was_logged ? false : (was_logged = true);
126152
}
153+
};
127154

155+
struct LogAlways {
156+
/// @brief Always returns true, for any type T.
128157
template <class T>
129-
void log_impl(const T&)
158+
constexpr bool should_log(const T&)
130159
{
160+
return true;
131161
}
132162
};
133-
134-
template <class... Loggers>
135-
using HistoryWithMemoryWriter = History<DataWriterToMemory, Loggers...>;
163+
/** @} */
136164

137165
} // namespace mio
138-
#endif //HISTORY_OBJ_H
166+
167+
#endif // MIO_IO_HISTORY_H

cpp/memilio/utils/metaprogramming.h

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,38 @@ struct index_of_type {
246246
template <class Type, class... Types>
247247
constexpr std::size_t index_of_type_v = index_of_type<Type, Types...>::value;
248248

249+
/**
250+
* Tests whether the list Types contains any type multiple times.
251+
* @tparam Types A list of types.
252+
*/
253+
template <class... Types>
254+
struct has_duplicates {
255+
private:
256+
/**
257+
* @brief Checks if Types has a duplicate entry using an index sequence.
258+
* @tparam Indices Exactly the list '0, ... , sizeof...(Types) - 1'. Use std::make_index_sequence.
259+
* @return True if Types contains a duplicate type, false otherwise.
260+
*/
261+
template <std::size_t... Indices>
262+
static constexpr bool has_duplicates_impl(std::index_sequence<Indices...>)
263+
{
264+
// index_of_type_v will always be equal to the index of the first occurance of a type,
265+
// while Indices contains its actual index. Hence, if there is any mismatch, then there is a duplicate.
266+
return ((index_of_type_v<Types, Types...> != Indices) || ...);
267+
}
268+
269+
public:
270+
static constexpr bool value = has_duplicates_impl(std::make_index_sequence<sizeof...(Types)>{});
271+
};
272+
273+
/**
274+
* @brief Checks whether Type has any duplicates.
275+
* Equivalent to has_duplicates<Types...>::value.
276+
* @see is_type_in_list.
277+
*/
278+
template <class... Types>
279+
constexpr bool has_duplicates_v = has_duplicates<Types...>::value;
280+
249281
} // namespace mio
250282

251283
#endif

cpp/models/abm/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ add_library(abm
2727
vaccine.h
2828
mask.h
2929
mask.cpp
30+
common_abm_loggers.h
3031
)
3132
target_link_libraries(abm PUBLIC memilio)
3233
target_include_directories(abm PUBLIC

0 commit comments

Comments
 (0)