Skip to content

Commit 9142b19

Browse files
authored
756 ABM benchmark (#828)
Add ABM benchmark, needs to be added to CI later.
1 parent 7cdbf18 commit 9142b19

5 files changed

Lines changed: 204 additions & 0 deletions

File tree

cpp/benchmarks/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,3 +50,6 @@ add_executable(simulation_benchmark simulation.cpp)
5050
target_link_libraries(simulation_benchmark PRIVATE memilio ode_secir benchmark::benchmark)
5151
add_executable(graph_simulation_benchmark graph_simulation.cpp)
5252
target_link_libraries(graph_simulation_benchmark PRIVATE memilio ode_secirvvs benchmark::benchmark)
53+
54+
add_executable(abm_benchmark abm.cpp)
55+
target_link_libraries(abm_benchmark PRIVATE abm benchmark::benchmark)

cpp/benchmarks/abm.cpp

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
#include "abm/simulation.h"
2+
#include "memilio/utils/stl_util.h"
3+
#include "benchmark/benchmark.h"
4+
5+
mio::abm::Simulation make_simulation(size_t num_persons, std::initializer_list<uint32_t> seeds)
6+
{
7+
auto rng = mio::RandomNumberGenerator();
8+
rng.seed(seeds);
9+
auto world = mio::abm::World(5);
10+
world.get_rng() = rng;
11+
12+
//create persons at home
13+
const auto mean_home_size = 5.0;
14+
const auto min_home_size = 1;
15+
auto& home_size_distribution = mio::PoissonDistribution<int>::get_instance();
16+
auto home = world.add_location(mio::abm::LocationType::Home);
17+
auto planned_home_size = home_size_distribution(world.get_rng(), mean_home_size);
18+
auto home_size = 0;
19+
for (size_t i = 0; i < num_persons; ++i) {
20+
if (home_size >= std::max(min_home_size, planned_home_size)) {
21+
home = world.add_location(mio::abm::LocationType::Home);
22+
planned_home_size = home_size_distribution(world.get_rng(), mean_home_size);
23+
home_size = 0;
24+
}
25+
26+
auto age = mio::AgeGroup(mio::UniformIntDistribution<size_t>::get_instance()(
27+
world.get_rng(), size_t(0), world.parameters.get_num_groups() - 1));
28+
auto& person = world.add_person(home, age);
29+
person.set_assigned_location(home);
30+
home_size++;
31+
}
32+
33+
//create other locations
34+
for (auto loc_type :
35+
{mio::abm::LocationType::School, mio::abm::LocationType::Work, mio::abm::LocationType::SocialEvent,
36+
mio::abm::LocationType::BasicsShop, mio::abm::LocationType::Hospital, mio::abm::LocationType::ICU}) {
37+
38+
const auto num_locs = std::max(size_t(1), num_persons / 2'000);
39+
std::vector<mio::abm::LocationId> locs(num_locs);
40+
std::generate(locs.begin(), locs.end(), [&] {
41+
return world.add_location(loc_type);
42+
});
43+
for (auto& person : world.get_persons()) {
44+
auto loc_idx =
45+
mio::UniformIntDistribution<size_t>::get_instance()(world.get_rng(), size_t(0), num_locs - 1);
46+
person.set_assigned_location(locs[loc_idx]);
47+
}
48+
}
49+
50+
//infections and masks
51+
for (auto& person : world.get_persons()) {
52+
auto prng = mio::abm::Person::RandomNumberGenerator(world.get_rng(), person);
53+
//some % of people are infected, large enough to have some infection activity without everyone dying
54+
auto pct_infected = 0.05;
55+
if (mio::UniformDistribution<double>::get_instance()(prng, 0.0, 1.0) < pct_infected) {
56+
auto state = mio::abm::InfectionState(
57+
mio::UniformIntDistribution<int>::get_instance()(prng, 1, int(mio::abm::InfectionState::Count) - 1));
58+
auto infection =
59+
mio::abm::Infection(prng, mio::abm::VirusVariant::Wildtype, person.get_age(),
60+
world.parameters, mio::abm::TimePoint(0), state);
61+
person.add_new_infection(std::move(infection));
62+
}
63+
64+
//equal chance of (moderate) mask refusal and (moderate) mask eagerness
65+
auto pct_mask_values = std::array{0.05 /*-1*/, 0.2 /*-0.5*/, 0.5 /*0*/, 0.2 /*0.5*/, 0.05 /*1*/};
66+
auto mask_value = -1 + 0.5 * mio::DiscreteDistribution<int>::get_instance()(prng, pct_mask_values);
67+
person.set_mask_preferences({size_t(mio::abm::LocationType::Count), mask_value});
68+
}
69+
70+
//masks at locations
71+
for (auto& loc : world.get_locations())
72+
{
73+
//some % of locations require masks
74+
//skip homes so persons always have a place to go, simulation might break otherwise
75+
auto pct_require_mask = 0.2;
76+
auto requires_mask = loc.get_type() != mio::abm::LocationType::Home &&
77+
mio::UniformDistribution<double>::get_instance()(world.get_rng()) < pct_require_mask;
78+
loc.set_npi_active(requires_mask);
79+
}
80+
81+
//testing schemes
82+
auto sample = [&](auto v, size_t n) { //selects n elements from list v
83+
std::shuffle(v.begin(), v.end(), world.get_rng());
84+
return std::vector<typename decltype(v)::value_type>(v.begin(), v.begin() + n);
85+
};
86+
std::vector<mio::AgeGroup> ages;
87+
std::generate_n(std::back_inserter(ages), world.parameters.get_num_groups(), [a = 0]() mutable {
88+
return mio::AgeGroup(a++);
89+
});
90+
auto random_criteria = [&]() {
91+
auto random_ages = sample(ages, 2);
92+
auto random_states = std::vector<mio::abm::InfectionState>(0);
93+
return mio::abm::TestingCriteria(random_ages, random_states);
94+
};
95+
96+
world.get_testing_strategy().add_testing_scheme(
97+
mio::abm::LocationType::School,
98+
mio::abm::TestingScheme(random_criteria(), mio::abm::days(3), mio::abm::TimePoint(0),
99+
mio::abm::TimePoint(0) + mio::abm::days(10), {}, 0.5));
100+
world.get_testing_strategy().add_testing_scheme(
101+
mio::abm::LocationType::Work,
102+
mio::abm::TestingScheme(random_criteria(), mio::abm::days(3), mio::abm::TimePoint(0),
103+
mio::abm::TimePoint(0) + mio::abm::days(10), {}, 0.5));
104+
world.get_testing_strategy().add_testing_scheme(
105+
mio::abm::LocationType::Home,
106+
mio::abm::TestingScheme(random_criteria(), mio::abm::days(3), mio::abm::TimePoint(0),
107+
mio::abm::TimePoint(0) + mio::abm::days(10), {}, 0.5));
108+
world.get_testing_strategy().add_testing_scheme(
109+
mio::abm::LocationType::SocialEvent,
110+
mio::abm::TestingScheme(random_criteria(), mio::abm::days(3), mio::abm::TimePoint(0),
111+
mio::abm::TimePoint(0) + mio::abm::days(10), {}, 0.5));
112+
113+
return mio::abm::Simulation(mio::abm::TimePoint(0), std::move(world));
114+
}
115+
116+
/**
117+
* Benchmark for the ABM simulation.
118+
* @param num_persons Number of persons in the simulation.
119+
* @param seeds Seeds for the random number generator.
120+
*/
121+
void abm_benchmark(benchmark::State& state, size_t num_persons, std::initializer_list<uint32_t> seeds)
122+
{
123+
mio::set_log_level(mio::LogLevel::warn);
124+
125+
for (auto&& _ : state) {
126+
state.PauseTiming(); //exclude the setup from the benchmark
127+
auto sim = make_simulation(num_persons, seeds);
128+
state.ResumeTiming();
129+
130+
//simulated time should be long enough to have full infection runs and migration to every location
131+
auto final_time = sim.get_time() + mio::abm::days(10);
132+
sim.advance(final_time);
133+
134+
//debug output can be enabled to check for unexpected results (e.g. infections dieing out)
135+
//normally should have no significant effect on runtime
136+
const bool monitor_infection_activity = false;
137+
if constexpr (monitor_infection_activity) {
138+
std::cout << "num_persons = " << num_persons << "\n";
139+
std::cout << sim.get_result()[0].transpose() << "\n";
140+
std::cout << sim.get_result().get_last_value().transpose() << "\n";
141+
}
142+
}
143+
}
144+
145+
//Measure ABM simulation run time with different sizes and different seeds.
146+
//Fixed RNG seeds to make runs comparable. When there are code changes, the simulation will still
147+
//run differently due to different sequence of random numbers being drawn. But for large enough sizes
148+
//RNG should average out, so runs should be comparable even with code changes.
149+
//We run a few different benchmarks to hopefully catch abnormal cases. Then seeds may
150+
//have to be adjusted to get the benchmark back to normal.
151+
//For small sizes (e.g. 10k) extreme cases are too likely, i.e. infections die out
152+
//or overwhelm everything, so we don't benchmark these. Results should be mostly transferrable.
153+
BENCHMARK_CAPTURE(abm_benchmark, abm_benchmark_50k, 50000, {14159265u, 35897932u})->Unit(benchmark::kMillisecond);
154+
BENCHMARK_CAPTURE(abm_benchmark, abm_benchmark_100k, 100000, {38462643u, 38327950u})->Unit(benchmark::kMillisecond);
155+
BENCHMARK_CAPTURE(abm_benchmark, abm_benchmark_200k, 200000, {28841971u, 69399375u})->Unit(benchmark::kMillisecond);
156+
157+
BENCHMARK_MAIN();

cpp/memilio/utils/random_number_generator.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -666,6 +666,13 @@ using UniformIntDistribution = DistributionAdapter<std::uniform_int_distribution
666666
template <class Real>
667667
using UniformDistribution = DistributionAdapter<std::uniform_real_distribution<Real>>;
668668

669+
/**
670+
* adapted poisson_distribution.
671+
* @see DistributionAdapter
672+
*/
673+
template <class Int>
674+
using PoissonDistribution = DistributionAdapter<std::poisson_distribution<Int>>;
675+
669676
} // namespace mio
670677

671678
#endif

cpp/memilio/utils/stl_util.h

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222

2323
#include "memilio/utils/metaprogramming.h"
2424

25+
#include <array>
26+
#include <numeric>
2527
#include <vector>
2628
#include <algorithm>
2729
#include <utility>
@@ -289,6 +291,30 @@ bool contains(Iter b, Iter e, Pred p)
289291
return find_if(b, e, p) != e;
290292
}
291293

294+
/**
295+
* Get an std::array that contains all members of an enum class.
296+
* The enum class must be a valid index, i.e. members must be sequential starting at 0 and there must
297+
* be a member `Count` at the end, that will not be included in the array.
298+
* Example:
299+
* ```
300+
* enum class E { A, B, Count };
301+
* assert(enum_members<E>() == std::array<2, E>(E::A, E::B));
302+
* ```
303+
* @tparam T An enum class that is a valid index.
304+
* @return Array of all members of the enum class not including T::Count.
305+
*/
306+
template<class T>
307+
constexpr std::array<T, size_t(T::Count)> enum_members()
308+
{
309+
auto enum_members = std::array<T, size_t(T::Count)>{};
310+
auto indices = std::array<std::underlying_type_t<T>, size_t(T::Count)>{};
311+
std::iota(indices.begin(), indices.end(), std::underlying_type_t<T>(0));
312+
std::transform(indices.begin(), indices.end(), enum_members.begin(), [](auto i) {
313+
return T(i);
314+
});
315+
return enum_members;
316+
}
317+
292318
} // namespace mio
293319

294320
#endif //STL_UTIL_H

cpp/tests/test_stl_util.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,17 @@ TEST(TestContains, empty)
243243
}));
244244
}
245245

246+
TEST(EnumMembers, works)
247+
{
248+
enum class E
249+
{
250+
A,
251+
B,
252+
Count
253+
};
254+
ASSERT_THAT(mio::enum_members<E>(), testing::ElementsAre(E::A, E::B));
255+
}
256+
246257
TEST(TestContains, set_ostream_format)
247258
{
248259
std::ostringstream output;

0 commit comments

Comments
 (0)