Skip to content

Commit ee41c39

Browse files
Replace contact rates of ABM by ContactMatrix (#1516)
Co-authored-by: jubicker <[email protected]>
1 parent 77e510d commit ee41c39

7 files changed

Lines changed: 30 additions & 26 deletions

File tree

cpp/examples/abm_minimal.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,9 +99,8 @@ int main()
9999
// Increase aerosol transmission for all locations
100100
model.parameters.get<mio::abm::AerosolTransmissionRates>() = 10.0;
101101
// Increase contact rate for all people between 15 and 34 (i.e. people meet more often in the same location)
102-
model.get_location(work)
103-
.get_infection_parameters()
104-
.get<mio::abm::ContactRates>()[{age_group_15_to_34, age_group_15_to_34}] = 10.0;
102+
model.get_location(work).get_infection_parameters().get<mio::abm::ContactRates>().get_baseline()(
103+
age_group_15_to_34.get(), age_group_15_to_34.get()) = 10.0;
105104

106105
// People can get tested at work (and do this with 0.5 probability) from time point 0 to day 10.
107106
auto validity_period = mio::abm::days(1);

cpp/examples/abm_parameter_study.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -111,9 +111,8 @@ mio::abm::Model make_model(const mio::RandomNumberGenerator& rng)
111111
// Increase aerosol transmission for all locations
112112
model.parameters.get<mio::abm::AerosolTransmissionRates>() = 10.0;
113113
// Increase contact rate for all people between 15 and 34 (i.e. people meet more often in the same location)
114-
model.get_location(work)
115-
.get_infection_parameters()
116-
.get<mio::abm::ContactRates>()[{age_group_15_to_34, age_group_15_to_34}] = 10.0;
114+
model.get_location(work).get_infection_parameters().get<mio::abm::ContactRates>().get_baseline()(
115+
age_group_15_to_34.get(), age_group_15_to_34.get()) = 10.0;
117116

118117
// People can get tested at work (and do this with 0.5 probability) from time point 0 to day 10.
119118
auto validity_period = mio::abm::days(1);

cpp/models/abm/location.h

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -261,9 +261,6 @@ class Location
261261
}
262262

263263
private:
264-
friend DefaultFactory<Location>;
265-
Location() = default;
266-
267264
LocationType m_type; ///< Type of the Location.
268265
LocationId m_id; ///< Unique identifier for the Location in the Model owning it.
269266
LocalInfectionParameters m_parameters; ///< Infection parameters for the Location.
@@ -273,8 +270,16 @@ class Location
273270
m_geographical_location; ///< Geographical location (longitude and latitude) of the Location.
274271
int m_model_id; ///< Model id the location is in. Only used for ABM graph model or hybrid graph model.
275272
};
276-
277273
} // namespace abm
274+
275+
/// @brief Creates an instance of abm::Location for default serialization.
276+
template <>
277+
struct DefaultFactory<abm::Location> {
278+
static abm::Location create()
279+
{
280+
return abm::Location(abm::LocationType::Count, abm::LocationId::invalid_id());
281+
}
282+
};
278283
} // namespace mio
279284

280285
#endif

cpp/models/abm/model_functions.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,12 @@ ScalarType total_exposure_by_contacts(const ContactExposureRates& rates, const C
4343
age_receiver_group_size > 1) // adjust for the person not meeting themself
4444
{
4545
total_exposure += rates[{cell_index, virus, age_transmitter}] *
46-
params.get<ContactRates>()[{age_receiver, age_transmitter}] * age_receiver_group_size /
47-
(age_receiver_group_size - 1);
46+
params.get<ContactRates>().get_baseline()(age_receiver.get(), age_transmitter.get()) *
47+
age_receiver_group_size / (age_receiver_group_size - 1);
4848
}
4949
else {
5050
total_exposure += rates[{cell_index, virus, age_transmitter}] *
51-
params.get<ContactRates>()[{age_receiver, age_transmitter}];
51+
params.get<ContactRates>().get_baseline()(age_receiver.get(), age_transmitter.get());
5252
}
5353
}
5454
return total_exposure;
@@ -189,12 +189,15 @@ void adjust_contact_rates(Location& location, size_t num_agegroups)
189189
ScalarType total_contacts = 0.;
190190
// slizing would be preferred but is problematic since both Tags of ContactRates are AgeGroup
191191
for (auto contact_to = AgeGroup(0); contact_to < AgeGroup(num_agegroups); contact_to++) {
192-
total_contacts += location.get_infection_parameters().get<ContactRates>()[{contact_from, contact_to}];
192+
total_contacts += location.get_infection_parameters().get<ContactRates>().get_baseline()(contact_from.get(),
193+
contact_to.get());
193194
}
194195
if (total_contacts > location.get_infection_parameters().get<MaximumContacts>()) {
195196
for (auto contact_to = AgeGroup(0); contact_to < AgeGroup(num_agegroups); contact_to++) {
196-
location.get_infection_parameters().get<ContactRates>()[{contact_from, contact_to}] =
197-
location.get_infection_parameters().get<ContactRates>()[{contact_from, contact_to}] *
197+
location.get_infection_parameters().get<ContactRates>().get_baseline()(contact_from.get(),
198+
contact_to.get()) =
199+
location.get_infection_parameters().get<ContactRates>().get_baseline()(contact_from.get(),
200+
contact_to.get()) *
198201
location.get_infection_parameters().get<MaximumContacts>() / total_contacts;
199202
}
200203
}

cpp/models/abm/parameters.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -727,11 +727,11 @@ struct MaximumContacts {
727727
* contact rates
728728
*/
729729
struct ContactRates {
730-
using Type = CustomIndexArray<ScalarType, AgeGroup, AgeGroup>;
730+
using Type = ContactMatrix<ScalarType>;
731731
static Type get_default(AgeGroup size)
732732
{
733-
return Type({size, size},
734-
1.0); // amount of contacts from AgeGroup a to AgeGroup b per day
733+
return Type(
734+
static_cast<Eigen::Index>((size_t)size)); // amount of contacts from AgeGroup a to AgeGroup b per day
735735
}
736736
static std::string name()
737737
{

cpp/tests/test_abm_location.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -178,10 +178,9 @@ TEST_F(TestLocation, adjustContactRates)
178178
{
179179
mio::abm::Location loc(mio::abm::LocationType::SocialEvent, mio::abm::LocationId(0));
180180
//Set the maximum contacts smaller than the contact rates
181-
loc.get_infection_parameters().get<mio::abm::MaximumContacts>() = 2;
182-
loc.get_infection_parameters().get<mio::abm::ContactRates>()[{mio::AgeGroup(0), mio::AgeGroup(0)}] = 4;
181+
loc.get_infection_parameters().get<mio::abm::MaximumContacts>() = 2;
182+
loc.get_infection_parameters().get<mio::abm::ContactRates>().get_baseline()(0, 0) = 4;
183183
mio::abm::adjust_contact_rates(loc, 1);
184-
auto adjusted_contacts_rate =
185-
loc.get_infection_parameters().get<mio::abm::ContactRates>()[{mio::AgeGroup(0), mio::AgeGroup(0)}];
184+
auto adjusted_contacts_rate = loc.get_infection_parameters().get<mio::abm::ContactRates>().get_baseline()(0, 0);
186185
EXPECT_EQ(adjusted_contacts_rate, 2);
187186
}

docs/source/cpp/mobility_based_abm.rst

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -204,9 +204,8 @@ We can also set the contact rates for specific age groups at a location:
204204

205205
.. code-block:: cpp
206206
207-
model.get_location(work)
208-
.get_infection_parameters()
209-
.get<mio::abm::ContactRates>()[{age_group_15_to_34, age_group_15_to_34}] = 10.0;
207+
model.get_location(work).get_infection_parameters().get<mio::abm::ContactRates>().get_baseline()(
208+
age_group_15_to_34.get(), age_group_15_to_34.get()) = 10.0;
210209
211210
For a full list of parameters, see `here <https://memilio.readthedocs.io/en/latest/api/file__home_docs_checkouts_readthedocs.org_user_builds_memilio_checkouts_latest_cpp_models_abm_parameters.h.html>`_.
212211

0 commit comments

Comments
 (0)