Skip to content

Commit 1b251d9

Browse files
authored
1250 enhance testing logic in abm (#1276)
- Enhance logic, readability and access time - Simplify: dont update activity status -> error prone and not better runtime
1 parent b738898 commit 1b251d9

14 files changed

Lines changed: 575 additions & 331 deletions

File tree

cpp/benchmarks/abm.cpp

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -111,20 +111,9 @@ mio::abm::Simulation<> make_simulation(size_t num_persons, std::initializer_list
111111
return mio::abm::TestingCriteria(random_ages, random_states);
112112
};
113113

114-
model.get_testing_strategy().add_testing_scheme(
115-
mio::abm::LocationType::School,
116-
mio::abm::TestingScheme(random_criteria(), mio::abm::days(3), mio::abm::TimePoint(0),
117-
mio::abm::TimePoint(0) + mio::abm::days(10), {}, 0.5));
118-
model.get_testing_strategy().add_testing_scheme(
119-
mio::abm::LocationType::Work,
120-
mio::abm::TestingScheme(random_criteria(), mio::abm::days(3), mio::abm::TimePoint(0),
121-
mio::abm::TimePoint(0) + mio::abm::days(10), {}, 0.5));
122-
model.get_testing_strategy().add_testing_scheme(
123-
mio::abm::LocationType::Home,
124-
mio::abm::TestingScheme(random_criteria(), mio::abm::days(3), mio::abm::TimePoint(0),
125-
mio::abm::TimePoint(0) + mio::abm::days(10), {}, 0.5));
126-
model.get_testing_strategy().add_testing_scheme(
127-
mio::abm::LocationType::SocialEvent,
114+
model.get_testing_strategy().add_scheme(
115+
{mio::abm::LocationType::School, mio::abm::LocationType::Work, mio::abm::LocationType::SocialEvent,
116+
mio::abm::LocationType::Home},
128117
mio::abm::TestingScheme(random_criteria(), mio::abm::days(3), mio::abm::TimePoint(0),
129118
mio::abm::TimePoint(0) + mio::abm::days(10), {}, 0.5));
130119

cpp/examples/abm_history_object.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ int main()
137137
auto testing_criteria_work = mio::abm::TestingCriteria();
138138
auto testing_scheme_work = mio::abm::TestingScheme(testing_criteria_work, validity_period, start_date, end_date,
139139
test_parameters, probability);
140-
model.get_testing_strategy().add_testing_scheme(mio::abm::LocationType::Work, testing_scheme_work);
140+
model.get_testing_strategy().add_scheme(mio::abm::LocationType::Work, testing_scheme_work);
141141

142142
// Assign infection state to each person.
143143
// The infection states are chosen randomly.

cpp/examples/abm_minimal.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,8 @@ int main()
113113
auto test_parameters = model.parameters.get<mio::abm::TestData>()[test_type];
114114
auto testing_criteria_work = mio::abm::TestingCriteria();
115115
auto testing_scheme_work = mio::abm::TestingScheme(testing_criteria_work, validity_period, start_date, end_date,
116-
test_parameters, probability);
117-
model.get_testing_strategy().add_testing_scheme(mio::abm::LocationType::Work, testing_scheme_work);
116+
test_parameters, probability);
117+
model.get_testing_strategy().add_scheme(mio::abm::LocationType::Work, testing_scheme_work);
118118

119119
// Assign infection state to each person.
120120
// The infection states are chosen randomly with the following distribution

cpp/models/abm/model.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -123,10 +123,14 @@ void Model::perform_mobility(TimePoint t, TimeSpan dt)
123123
get_number_persons(target_location.get_id()) >= target_location.get_capacity().persons) {
124124
return false;
125125
}
126-
// the Person cannot move if the performed TestingStrategy is positive
127-
if (!m_testing_strategy.run_strategy(personal_rng, person, target_location, t)) {
126+
// The person cannot move if he has a positive test result, except he want to go to a hospital, ICU or home.
127+
if (!m_testing_strategy.run_and_check(personal_rng, person, target_location, t) &&
128+
target_location.get_type() != LocationType::Hospital &&
129+
target_location.get_type() != LocationType::ICU &&
130+
target_location.get_type() != LocationType::Home) {
128131
return false;
129132
}
133+
130134
// update worn mask to target location's requirements
131135
if (target_location.is_mask_required()) {
132136
// if the current MaskProtection level is lower than required, the Person changes mask
@@ -189,7 +193,9 @@ void Model::perform_mobility(TimePoint t, TimeSpan dt)
189193
continue;
190194
}
191195
// skip the trip if the performed TestingStrategy is positive
192-
if (!m_testing_strategy.run_strategy(personal_rng, person, target_location, t)) {
196+
if (!m_testing_strategy.run_and_check(personal_rng, person, target_location, t) &&
197+
target_location.get_type() != LocationType::Hospital && target_location.get_type() != LocationType::ICU &&
198+
target_location.get_type() != LocationType::Home) {
193199
continue;
194200
}
195201
// all requirements are met, move to target location
@@ -296,8 +302,6 @@ void Model::compute_exposure_caches(TimePoint t, TimeSpan dt)
296302

297303
void Model::begin_step(TimePoint t, TimeSpan dt)
298304
{
299-
m_testing_strategy.update_activity_status(t);
300-
301305
if (!m_is_local_population_cache_valid) {
302306
build_compute_local_population_cache();
303307
m_is_local_population_cache_valid = true;

cpp/models/abm/model.h

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -341,22 +341,6 @@ class Model
341341
return m_id;
342342
}
343343

344-
/**
345-
* @brief Add a TestingScheme to the set of schemes that are checked for testing at all Locations that have
346-
* the LocationType.
347-
* @param[in] loc_type LocationId key for TestingScheme to be added.
348-
* @param[in] scheme TestingScheme to be added.
349-
*/
350-
void add_testing_scheme(const LocationType& loc_type, const TestingScheme& scheme);
351-
352-
/**
353-
* @brief Remove a TestingScheme from the set of schemes that are checked for testing at all Locations that have
354-
* the LocationType.
355-
* @param[in] loc_type LocationId key for TestingScheme to be added.
356-
* @param[in] scheme TestingScheme to be added.
357-
*/
358-
void remove_testing_scheme(const LocationType& loc_type, const TestingScheme& scheme);
359-
360344
/**
361345
* @brief Get a reference to a Person from this Model.
362346
* @param[in] person_id A Person's PersonId.

cpp/models/abm/testing_strategy.cpp

Lines changed: 59 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -43,26 +43,6 @@ bool TestingCriteria::operator==(const TestingCriteria& other) const
4343
return m_ages == other.m_ages && m_infection_states == other.m_infection_states;
4444
}
4545

46-
void TestingCriteria::add_age_group(const AgeGroup age_group)
47-
{
48-
m_ages.set(static_cast<size_t>(age_group), true);
49-
}
50-
51-
void TestingCriteria::remove_age_group(const AgeGroup age_group)
52-
{
53-
m_ages.set(static_cast<size_t>(age_group), false);
54-
}
55-
56-
void TestingCriteria::add_infection_state(const InfectionState infection_state)
57-
{
58-
m_infection_states.set(static_cast<size_t>(infection_state), true);
59-
}
60-
61-
void TestingCriteria::remove_infection_state(const InfectionState infection_state)
62-
{
63-
m_infection_states.set(static_cast<size_t>(infection_state), false);
64-
}
65-
6646
bool TestingCriteria::evaluate(const Person& p, TimePoint t) const
6747
{
6848
// An empty vector of ages or none bitset of #InfectionStates% means that no condition on the corresponding property is set.
@@ -79,6 +59,7 @@ TestingScheme::TestingScheme(const TestingCriteria& testing_criteria, TimeSpan v
7959
, m_test_parameters(test_parameters)
8060
, m_probability(probability)
8161
{
62+
assert(start_date <= end_date && "Start date must be before or equal to end date");
8263
}
8364

8465
bool TestingScheme::operator==(const TestingScheme& other) const
@@ -91,122 +72,104 @@ bool TestingScheme::operator==(const TestingScheme& other) const
9172
//To be adjusted and also TestType should be static.
9273
}
9374

94-
bool TestingScheme::is_active() const
75+
bool TestingScheme::is_active(TimePoint t) const
9576
{
96-
return m_is_active;
77+
return (m_start_date <= t && t < m_end_date);
9778
}
9879

99-
void TestingScheme::update_activity_status(TimePoint t)
80+
bool TestingScheme::run_and_test(PersonalRandomNumberGenerator& rng, Person& person, TimePoint t) const
10081
{
101-
m_is_active = (m_start_date <= t && t <= m_end_date);
102-
}
82+
if (!is_active(t)) { // If the scheme is not active, do nothing; early return
83+
return false;
84+
}
10385

104-
bool TestingScheme::run_scheme(PersonalRandomNumberGenerator& rng, Person& person, TimePoint t) const
105-
{
10686
auto test_result = person.get_test_result(m_test_parameters.type);
10787
// If the agent has a test result valid until now, use the result directly
10888
if ((test_result.time_of_testing > TimePoint(std::numeric_limits<int>::min())) &&
10989
(test_result.time_of_testing + m_validity_period >= t)) {
110-
return !test_result.result;
90+
return test_result.result; // If the test is positive, the entry is not allowed, and vice versa
91+
}
92+
if (person.get_compliance(InterventionType::Testing) <
93+
1.0 && // Dont need to draw a random number if the person is compliant either way
94+
!person.is_compliant(
95+
rng, InterventionType::Testing)) { // If the person is not compliant with the testing intervention
96+
return true;
11197
}
11298
// Otherwise, the time_of_testing in the past (i.e. the agent has already performed it).
11399
if (m_testing_criteria.evaluate(person, t - m_test_parameters.required_time)) {
114100
double random = UniformDistribution<double>::get_instance()(rng);
115101
if (random < m_probability) {
116102
bool result = person.get_tested(rng, t - m_test_parameters.required_time, m_test_parameters);
117103
person.add_test_result(t, m_test_parameters.type, result);
118-
return !result;
104+
return result; // If the test is positive, the entry is not allowed, and vice versa
119105
}
120106
}
121-
return true;
107+
// If the test is not performed, the entry is allowed
108+
return false;
122109
}
123110

124-
TestingStrategy::TestingStrategy(const std::vector<LocalStrategy>& location_to_schemes_map)
125-
: m_location_to_schemes_map(location_to_schemes_map.begin(), location_to_schemes_map.end())
111+
TestingStrategy::TestingStrategy(const std::vector<LocalStrategy>& location_to_schemes_id,
112+
const std::vector<LocalStrategy>& location_to_schemes_type)
113+
: m_testing_schemes_at_location_id(location_to_schemes_id.begin(), location_to_schemes_id.end())
114+
, m_testing_schemes_at_location_type(location_to_schemes_type.begin(), location_to_schemes_type.end())
126115
{
127116
}
128117

129-
void TestingStrategy::add_testing_scheme(const LocationType& loc_type, const LocationId& loc_id,
130-
const TestingScheme& scheme)
118+
void TestingStrategy::add_scheme(const LocationId& loc_id, const TestingScheme& scheme)
131119
{
132-
auto iter_schemes =
133-
std::find_if(m_location_to_schemes_map.begin(), m_location_to_schemes_map.end(), [&](const auto& p) {
134-
return p.type == loc_type && p.id == loc_id;
135-
});
136-
if (iter_schemes == m_location_to_schemes_map.end()) {
137-
//no schemes for this location yet, add a new list with one scheme
138-
m_location_to_schemes_map.push_back({loc_type, loc_id, std::vector<TestingScheme>(1, scheme)});
139-
}
140-
else {
141-
//add scheme to existing vector if the scheme doesn't exist yet
142-
auto& schemes = iter_schemes->schemes;
143-
if (std::find(schemes.begin(), schemes.end(), scheme) == schemes.end()) {
144-
schemes.push_back(scheme);
145-
}
120+
if (loc_id.get() >= m_testing_schemes_at_location_id.size()) {
121+
m_testing_schemes_at_location_id.resize(loc_id.get() + 1);
146122
}
123+
m_testing_schemes_at_location_id[loc_id.get()].schemes.push_back(scheme);
147124
}
148125

149-
void TestingStrategy::remove_testing_scheme(const LocationType& loc_type, const LocationId& loc_id,
150-
const TestingScheme& scheme)
126+
void TestingStrategy::add_scheme(const LocationType& loc_type, const TestingScheme& scheme)
151127
{
152-
auto iter_schemes =
153-
std::find_if(m_location_to_schemes_map.begin(), m_location_to_schemes_map.end(), [&](const auto& p) {
154-
return p.type == loc_type && p.id == loc_id;
155-
});
156-
if (iter_schemes != m_location_to_schemes_map.end()) {
157-
//remove the scheme from the list
158-
auto& schemes_vector = iter_schemes->schemes;
159-
auto last = std::remove(schemes_vector.begin(), schemes_vector.end(), scheme);
160-
schemes_vector.erase(last, schemes_vector.end());
161-
//delete the list of schemes for this location if no schemes left
162-
if (schemes_vector.empty()) {
163-
m_location_to_schemes_map.erase(iter_schemes);
164-
}
128+
if ((size_t)loc_type >= m_testing_schemes_at_location_type.size()) {
129+
m_testing_schemes_at_location_type.resize((size_t)loc_type + 1);
165130
}
131+
m_testing_schemes_at_location_type[(size_t)loc_type].schemes.push_back(scheme);
166132
}
167133

168-
void TestingStrategy::update_activity_status(TimePoint t)
134+
bool TestingStrategy::run_and_check(PersonalRandomNumberGenerator& rng, Person& person, const Location& location,
135+
TimePoint t)
169136
{
170-
for (auto& [_type, _id, testing_schemes] : m_location_to_schemes_map) {
171-
for (auto& scheme : testing_schemes) {
172-
scheme.update_activity_status(t);
173-
}
174-
}
175-
}
137+
// Early return if no scheme defined for this location or type
138+
auto loc_id = location.get_id().get();
139+
auto loc_type = static_cast<size_t>(location.get_type());
176140

177-
bool TestingStrategy::run_strategy(PersonalRandomNumberGenerator& rng, Person& person, const Location& location,
178-
TimePoint t)
179-
{
180-
// A Person is always allowed to go home and this is never called if a person is not discharged from a hospital or ICU.
181-
if (location.get_type() == mio::abm::LocationType::Home) {
182-
return true;
141+
bool has_id_schemes =
142+
loc_id < m_testing_schemes_at_location_id.size() && !m_testing_schemes_at_location_id[loc_id].schemes.empty();
143+
144+
bool has_type_schemes = loc_type < m_testing_schemes_at_location_type.size() &&
145+
!m_testing_schemes_at_location_type[loc_type].schemes.empty();
146+
147+
if (!has_id_schemes && !has_type_schemes) {
148+
return true; // No applicable schemes
183149
}
184150

185-
// If the Person does not comply to Testing where there is a testing scheme at the target location, it is not allowed to enter.
186-
if (!person.is_compliant(rng, InterventionType::Testing)) {
187-
return false;
151+
bool entry_allowed = true; // Assume entry is allowed unless a scheme denies it
152+
// Check schemes for specific location id
153+
if (has_id_schemes) {
154+
for (const auto& scheme : m_testing_schemes_at_location_id[loc_id].schemes) {
155+
if (scheme.run_and_test(rng, person, t)) {
156+
entry_allowed = false; // Deny entry
157+
}
158+
}
188159
}
189160

190-
// Lookup schemes for this specific location as well as the location type
191-
// Lookup in std::vector instead of std::map should be much faster unless for large numbers of schemes
192-
for (auto key : {std::make_pair(location.get_type(), location.get_id()),
193-
std::make_pair(location.get_type(), LocationId::invalid_id())}) {
194-
auto iter_schemes =
195-
std::find_if(m_location_to_schemes_map.begin(), m_location_to_schemes_map.end(), [&](const auto& p) {
196-
return p.type == key.first && p.id == key.second;
197-
});
198-
if (iter_schemes != m_location_to_schemes_map.end()) {
199-
// Apply all testing schemes that are found
200-
auto& schemes = iter_schemes->schemes;
201-
// Whether the Person is allowed to enter or not depends on the test result(s).
202-
if (!std::all_of(schemes.begin(), schemes.end(), [&rng, &person, t](TestingScheme& ts) {
203-
return !ts.is_active() || ts.run_scheme(rng, person, t);
204-
})) {
205-
return false;
161+
// Check schemes for location type
162+
if (has_type_schemes) {
163+
for (const auto& scheme : m_testing_schemes_at_location_type[loc_type].schemes) {
164+
if (scheme.run_and_test(rng, person, t)) {
165+
entry_allowed = false; // Deny entry
206166
}
207167
}
208168
}
209-
return true;
169+
170+
// If the location is a home, entry is always allowed regardless of testing, no early return here because we still need to test
171+
172+
return entry_allowed;
210173
}
211174

212175
} // namespace abm

0 commit comments

Comments
 (0)