Skip to content

Commit f0abb25

Browse files
HenrZu and charlie0614 authored
1315 support IO functionality with single age group for ODE SECIR (#1354)
- When initializing the ODE SECIR model with real data, we were restricted to a model with 6 age groups (matching the officially reported data). - Now, a model with a single age group can also be passed in and is initialized correctly. - Short fix for the readthedocs build. Co-authored-by: Carlotta Gerstein <[email protected]>
1 parent 97b2d65 commit f0abb25

3 files changed

Lines changed: 252 additions & 21 deletions

File tree

cpp/models/ode_secir/parameters_io.h

Lines changed: 69 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,17 @@ IOResult<void> set_confirmed_cases_data(std::vector<Model<FP>>& model, std::vect
7575
const std::vector<double>& scaling_factor_inf)
7676
{
7777
const size_t num_age_groups = ConfirmedCasesDataEntry::age_group_names.size();
78-
assert(scaling_factor_inf.size() == num_age_groups);
78+
// allow single scalar scaling that is broadcast to all age groups
79+
assert(scaling_factor_inf.size() == 1 || scaling_factor_inf.size() == num_age_groups);
80+
81+
// Set scaling factors to match num age groups
82+
std::vector<double> scaling_factor_inf_full;
83+
if (scaling_factor_inf.size() == 1) {
84+
scaling_factor_inf_full.assign(num_age_groups, scaling_factor_inf[0]);
85+
}
86+
else {
87+
scaling_factor_inf_full = scaling_factor_inf;
88+
}
7989

8090
std::vector<std::vector<int>> t_InfectedNoSymptoms{model.size()};
8191
std::vector<std::vector<int>> t_Exposed{model.size()};
@@ -89,7 +99,12 @@ IOResult<void> set_confirmed_cases_data(std::vector<Model<FP>>& model, std::vect
8999
std::vector<std::vector<double>> mu_U_D{model.size()};
90100

91101
for (size_t node = 0; node < model.size(); ++node) {
92-
for (size_t group = 0; group < num_age_groups; group++) {
102+
const size_t model_groups = (size_t)model[node].parameters.get_num_groups();
103+
assert(model_groups == 1 || model_groups == num_age_groups);
104+
for (size_t ag = 0; ag < num_age_groups; ag++) {
105+
// If the model has fewer groups than case data entries available,
106+
// reuse group 0 parameters for all RKI age groups
107+
const size_t group = (model_groups == num_age_groups) ? ag : 0;
93108

94109
t_Exposed[node].push_back(
95110
static_cast<int>(std::round(model[node].parameters.template get<TimeExposed<FP>>()[(AgeGroup)group])));
@@ -121,26 +136,49 @@ IOResult<void> set_confirmed_cases_data(std::vector<Model<FP>>& model, std::vect
121136
BOOST_OUTCOME_TRY(read_confirmed_cases_data(case_data, region, date, num_Exposed, num_InfectedNoSymptoms,
122137
num_InfectedSymptoms, num_InfectedSevere, num_icu, num_death, num_rec,
123138
t_Exposed, t_InfectedNoSymptoms, t_InfectedSymptoms, t_InfectedSevere,
124-
t_InfectedCritical, mu_C_R, mu_I_H, mu_H_U, scaling_factor_inf));
139+
t_InfectedCritical, mu_C_R, mu_I_H, mu_H_U, scaling_factor_inf_full));
125140

126141
for (size_t node = 0; node < model.size(); node++) {
127142
if (std::accumulate(num_InfectedSymptoms[node].begin(), num_InfectedSymptoms[node].end(), 0.0) > 0) {
128143
size_t num_groups = (size_t)model[node].parameters.get_num_groups();
129-
for (size_t i = 0; i < num_groups; i++) {
130-
model[node].populations[{AgeGroup(i), InfectionState::Exposed}] = num_Exposed[node][i];
131-
model[node].populations[{AgeGroup(i), InfectionState::InfectedNoSymptoms}] =
132-
num_InfectedNoSymptoms[node][i];
133-
model[node].populations[{AgeGroup(i), InfectionState::InfectedNoSymptomsConfirmed}] = 0;
134-
model[node].populations[{AgeGroup(i), InfectionState::InfectedSymptoms}] =
135-
num_InfectedSymptoms[node][i];
136-
model[node].populations[{AgeGroup(i), InfectionState::InfectedSymptomsConfirmed}] = 0;
137-
model[node].populations[{AgeGroup(i), InfectionState::InfectedSevere}] = num_InfectedSevere[node][i];
138-
// Only set the number of ICU patients here, if the date is not available in the data.
144+
if (num_groups == num_age_groups) {
145+
for (size_t i = 0; i < num_groups; i++) {
146+
model[node].populations[{AgeGroup(i), InfectionState::Exposed}] = num_Exposed[node][i];
147+
model[node].populations[{AgeGroup(i), InfectionState::InfectedNoSymptoms}] =
148+
num_InfectedNoSymptoms[node][i];
149+
model[node].populations[{AgeGroup(i), InfectionState::InfectedNoSymptomsConfirmed}] = 0;
150+
model[node].populations[{AgeGroup(i), InfectionState::InfectedSymptoms}] =
151+
num_InfectedSymptoms[node][i];
152+
model[node].populations[{AgeGroup(i), InfectionState::InfectedSymptomsConfirmed}] = 0;
153+
model[node].populations[{AgeGroup(i), InfectionState::InfectedSevere}] =
154+
num_InfectedSevere[node][i];
155+
// Only set the number of ICU patients here, if the date is not available in the data.
156+
if (!is_divi_data_available(date)) {
157+
model[node].populations[{AgeGroup(i), InfectionState::InfectedCritical}] = num_icu[node][i];
158+
}
159+
model[node].populations[{AgeGroup(i), InfectionState::Dead}] = num_death[node][i];
160+
model[node].populations[{AgeGroup(i), InfectionState::Recovered}] = num_rec[node][i];
161+
}
162+
}
163+
else {
164+
const auto sum_vec = [](const std::vector<double>& v) {
165+
return std::accumulate(v.begin(), v.end(), 0.0);
166+
};
167+
const size_t i0 = 0;
168+
model[node].populations[{AgeGroup(i0), InfectionState::Exposed}] = sum_vec(num_Exposed[node]);
169+
model[node].populations[{AgeGroup(i0), InfectionState::InfectedNoSymptoms}] =
170+
sum_vec(num_InfectedNoSymptoms[node]);
171+
model[node].populations[{AgeGroup(i0), InfectionState::InfectedNoSymptomsConfirmed}] = 0;
172+
model[node].populations[{AgeGroup(i0), InfectionState::InfectedSymptoms}] =
173+
sum_vec(num_InfectedSymptoms[node]);
174+
model[node].populations[{AgeGroup(i0), InfectionState::InfectedSymptomsConfirmed}] = 0;
175+
model[node].populations[{AgeGroup(i0), InfectionState::InfectedSevere}] =
176+
sum_vec(num_InfectedSevere[node]);
139177
if (!is_divi_data_available(date)) {
140-
model[node].populations[{AgeGroup(i), InfectionState::InfectedCritical}] = num_icu[node][i];
178+
model[node].populations[{AgeGroup(i0), InfectionState::InfectedCritical}] = sum_vec(num_icu[node]);
141179
}
142-
model[node].populations[{AgeGroup(i), InfectionState::Dead}] = num_death[node][i];
143-
model[node].populations[{AgeGroup(i), InfectionState::Recovered}] = num_rec[node][i];
180+
model[node].populations[{AgeGroup(i0), InfectionState::Dead}] = sum_vec(num_death[node]);
181+
model[node].populations[{AgeGroup(i0), InfectionState::Recovered}] = sum_vec(num_rec[node]);
144182
}
145183
}
146184
else {
@@ -231,10 +269,20 @@ IOResult<void> set_population_data(std::vector<Model<FP>>& model,
231269
assert(num_population.size() == vregion.size());
232270
assert(model.size() == vregion.size());
233271
for (size_t region = 0; region < vregion.size(); region++) {
234-
auto num_groups = model[region].parameters.get_num_groups();
235-
for (auto i = AgeGroup(0); i < num_groups; i++) {
272+
const auto model_groups = (size_t)model[region].parameters.get_num_groups();
273+
const auto data_groups = num_population[region].size();
274+
assert(data_groups == model_groups || (model_groups == 1 && data_groups >= 1));
275+
276+
if (data_groups == model_groups) {
277+
for (auto i = AgeGroup(0); i < model[region].parameters.get_num_groups(); i++) {
278+
model[region].populations.template set_difference_from_group_total<AgeGroup>(
279+
{i, InfectionState::Susceptible}, num_population[region][(size_t)i]);
280+
}
281+
}
282+
else if (model_groups == 1 && data_groups >= 1) {
283+
const double total = std::accumulate(num_population[region].begin(), num_population[region].end(), 0.0);
236284
model[region].populations.template set_difference_from_group_total<AgeGroup>(
237-
{i, InfectionState::Susceptible}, num_population[region][size_t(i)]);
285+
{AgeGroup(0), InfectionState::Susceptible}, total);
238286
}
239287
}
240288
return success();
@@ -283,8 +331,8 @@ IOResult<void> export_input_data_county_timeseries(
283331
const std::string& divi_data_path, const std::string& confirmed_cases_path, const std::string& population_data_path)
284332
{
285333
const auto num_age_groups = (size_t)models[0].parameters.get_num_groups();
286-
assert(scaling_factor_inf.size() == num_age_groups);
287-
assert(num_age_groups == ConfirmedCasesDataEntry::age_group_names.size());
334+
// allow scalar scaling factor as convenience for 1-group models
335+
assert(scaling_factor_inf.size() == 1 || scaling_factor_inf.size() == num_age_groups);
288336
assert(models.size() == region.size());
289337
std::vector<TimeSeries<double>> extrapolated_data(
290338
region.size(), TimeSeries<double>::zero(num_days + 1, (size_t)InfectionState::Count * num_age_groups));

cpp/tests/test_odesecir.cpp

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include "memilio/io/parameters_io.h"
3030
#include "memilio/data/analyze_result.h"
3131
#include "memilio/math/adapt_rk.h"
32+
#include "memilio/geography/regions.h"
3233

3334
#include <gtest/gtest.h>
3435

@@ -1516,5 +1517,184 @@ TEST(TestOdeSecir, read_population_data_failure)
15161517
EXPECT_EQ(result.error().message(), "File with county population expected.");
15171518
}
15181519

1520+
TEST(TestOdeSecirIO, read_input_data_county_aggregates_one_group)
1521+
{
1522+
// Set up two models with different number of age groups.
1523+
const size_t num_age_groups = 6;
1524+
std::vector<mio::osecir::Model<double>> models6{mio::osecir::Model<double>((int)num_age_groups)};
1525+
std::vector<mio::osecir::Model<double>> models1{mio::osecir::Model<double>(1)};
1526+
1527+
// Relevant parameters for model with 6 age groups
1528+
for (auto i = mio::AgeGroup(0); i < (mio::AgeGroup)num_age_groups; ++i) {
1529+
models6[0].parameters.get<mio::osecir::SeverePerInfectedSymptoms<double>>()[i] = 0.2;
1530+
models6[0].parameters.get<mio::osecir::CriticalPerSevere<double>>()[i] = 0.25;
1531+
}
1532+
1533+
// Relevant parameters for model with 1 age group
1534+
models1[0].parameters.get<mio::osecir::SeverePerInfectedSymptoms<double>>()[mio::AgeGroup(0)] = 0.2;
1535+
models1[0].parameters.get<mio::osecir::CriticalPerSevere<double>>()[mio::AgeGroup(0)] = 0.25;
1536+
1537+
const auto pydata_dir_Germany = mio::path_join(TEST_DATA_DIR, "Germany", "pydata");
1538+
const std::vector<int> counties{1002};
1539+
const auto date = mio::Date(2020, 12, 1);
1540+
1541+
std::vector<double> scale6(num_age_groups, 1.0);
1542+
std::vector<double> scale1{1.0};
1543+
1544+
// Initialize both models
1545+
ASSERT_THAT(mio::osecir::read_input_data_county(models6, date, counties, scale6, 1.0, pydata_dir_Germany),
1546+
IsSuccess());
1547+
ASSERT_THAT(mio::osecir::read_input_data_county(models1, date, counties, scale1, 1.0, pydata_dir_Germany),
1548+
IsSuccess());
1549+
1550+
// Aggregate the results from the model with 6 age groups and compare with the model with 1 age group
1551+
const auto& m6 = models6[0];
1552+
const auto& m1 = models1[0];
1553+
const double tol = 1e-10;
1554+
for (int s = 0; s < (int)mio::osecir::InfectionState::Count; ++s) {
1555+
double sum6 = 0.0;
1556+
for (size_t ag = 0; ag < num_age_groups; ++ag) {
1557+
sum6 += m6.populations[{mio::AgeGroup(ag), (mio::osecir::InfectionState)s}].value();
1558+
}
1559+
const double v1 = m1.populations[{mio::AgeGroup(0), (mio::osecir::InfectionState)s}].value();
1560+
EXPECT_NEAR(sum6, v1, tol);
1561+
}
1562+
1563+
// Total population
1564+
EXPECT_NEAR(m6.populations.get_total(), m1.populations.get_total(), tol);
1565+
}
1566+
1567+
TEST(TestOdeSecirIO, set_population_data_single_age_group)
1568+
{
1569+
const size_t num_age_groups = 6;
1570+
1571+
// Create two models: one with 6 age groups, one with 1 age group
1572+
std::vector<mio::osecir::Model<double>> models6{mio::osecir::Model<double>((int)num_age_groups)};
1573+
std::vector<mio::osecir::Model<double>> models1{mio::osecir::Model<double>(1)};
1574+
1575+
// Test population data with 6 different values for age groups
1576+
std::vector<std::vector<double>> population_data6 = {{10000.0, 20000.0, 30000.0, 25000.0, 15000.0, 8000.0}};
1577+
std::vector<std::vector<double>> population_data1 = {{108000.0}}; // sum of all age groups
1578+
std::vector<int> regions = {1002};
1579+
1580+
// Set population data for both models
1581+
EXPECT_THAT(mio::osecir::details::set_population_data(models6, population_data6, regions), IsSuccess());
1582+
EXPECT_THAT(mio::osecir::details::set_population_data(models1, population_data1, regions), IsSuccess());
1583+
1584+
// Sum all compartments across age groups in 6-group model and compare 1-group model
1585+
const double tol = 1e-10;
1586+
for (int s = 0; s < (int)mio::osecir::InfectionState::Count; ++s) {
1587+
double sum6 = 0.0;
1588+
for (size_t ag = 0; ag < num_age_groups; ++ag) {
1589+
sum6 += models6[0].populations[{mio::AgeGroup(ag), (mio::osecir::InfectionState)s}].value();
1590+
}
1591+
double val1 = models1[0].populations[{mio::AgeGroup(0), (mio::osecir::InfectionState)s}].value();
1592+
1593+
EXPECT_NEAR(sum6, val1, tol);
1594+
}
1595+
1596+
// Total population should also match
1597+
EXPECT_NEAR(models6[0].populations.get_total(), models1[0].populations.get_total(), tol);
1598+
}
1599+
1600+
TEST(TestOdeSecirIO, set_confirmed_cases_data_single_age_group)
1601+
{
1602+
const size_t num_age_groups = 6;
1603+
1604+
// Create two models: one with 6 age groups, one with 1 age group
1605+
std::vector<mio::osecir::Model<double>> models6{mio::osecir::Model<double>((int)num_age_groups)};
1606+
std::vector<mio::osecir::Model<double>> models1{mio::osecir::Model<double>(1)};
1607+
1608+
// Create case data for all 6 age groups over multiple days (current day + 6 days back)
1609+
std::vector<mio::ConfirmedCasesDataEntry> case_data;
1610+
1611+
for (int day_offset = -6; day_offset <= 0; ++day_offset) {
1612+
mio::Date current_date = mio::offset_date_by_days(mio::Date(2020, 12, 1), day_offset);
1613+
1614+
for (int age_group = 0; age_group < 6; ++age_group) {
1615+
double base_confirmed = 80.0 + age_group * 8.0 + (day_offset + 6) * 5.0;
1616+
double base_recovered = 40.0 + age_group * 4.0 + (day_offset + 6) * 3.0;
1617+
double base_deaths = 3.0 + age_group * 0.5 + (day_offset + 6) * 0.5;
1618+
1619+
mio::ConfirmedCasesDataEntry entry{base_confirmed,
1620+
base_recovered,
1621+
base_deaths,
1622+
current_date,
1623+
mio::AgeGroup(age_group),
1624+
{},
1625+
mio::regions::CountyId(1002),
1626+
{}};
1627+
case_data.push_back(entry);
1628+
}
1629+
}
1630+
1631+
std::vector<int> regions = {1002};
1632+
std::vector<double> scaling_factors = {1.0};
1633+
1634+
// Set confirmed cases data for both models
1635+
EXPECT_THAT(mio::osecir::details::set_confirmed_cases_data(models6, case_data, regions, mio::Date(2020, 12, 1),
1636+
scaling_factors),
1637+
IsSuccess());
1638+
EXPECT_THAT(mio::osecir::details::set_confirmed_cases_data(models1, case_data, regions, mio::Date(2020, 12, 1),
1639+
scaling_factors),
1640+
IsSuccess());
1641+
1642+
// Sum all compartments across age groups in 6-group model should be equal to 1-group model
1643+
for (int s = 0; s < (int)mio::osecir::InfectionState::Count; ++s) {
1644+
double sum6 = 0.0;
1645+
for (size_t ag = 0; ag < num_age_groups; ++ag) {
1646+
sum6 += models6[0].populations[{mio::AgeGroup(ag), (mio::osecir::InfectionState)s}].value();
1647+
}
1648+
1649+
double val1 = models1[0].populations[{mio::AgeGroup(0), (mio::osecir::InfectionState)s}].value();
1650+
1651+
EXPECT_NEAR(sum6, val1, 1e-10);
1652+
}
1653+
1654+
// Total population
1655+
EXPECT_NEAR(models6[0].populations.get_total(), models1[0].populations.get_total(), 1e-10);
1656+
}
1657+
1658+
TEST(TestOdeSecirIO, set_divi_data_single_age_group)
1659+
{
1660+
// Create models with 6 age groups and 1 age group
1661+
std::vector<mio::osecir::Model<double>> models_6_groups{mio::osecir::Model<double>(6)};
1662+
std::vector<mio::osecir::Model<double>> models_1_group{mio::osecir::Model<double>(1)};
1663+
1664+
// Set relevant parameters for all age groups
1665+
for (int i = 0; i < 6; i++) {
1666+
models_6_groups[0].parameters.get<mio::osecir::SeverePerInfectedSymptoms<double>>()[mio::AgeGroup(i)] = 0.2;
1667+
models_6_groups[0].parameters.get<mio::osecir::CriticalPerSevere<double>>()[mio::AgeGroup(i)] = 0.25;
1668+
}
1669+
1670+
// Set relevant parameters for 1 age group model
1671+
models_1_group[0].parameters.get<mio::osecir::SeverePerInfectedSymptoms<double>>()[mio::AgeGroup(0)] = 0.2;
1672+
models_1_group[0].parameters.get<mio::osecir::CriticalPerSevere<double>>()[mio::AgeGroup(0)] = 0.25;
1673+
1674+
// Apply DIVI data to both models
1675+
std::vector<int> regions = {1002};
1676+
double scaling_factor_icu = 1.0;
1677+
mio::Date date(2020, 12, 1);
1678+
std::string divi_data_path = mio::path_join(TEST_DATA_DIR, "Germany", "pydata", "county_divi_ma7.json");
1679+
auto result_6_groups =
1680+
mio::osecir::details::set_divi_data(models_6_groups, divi_data_path, regions, date, scaling_factor_icu);
1681+
auto result_1_group =
1682+
mio::osecir::details::set_divi_data(models_1_group, divi_data_path, regions, date, scaling_factor_icu);
1683+
1684+
EXPECT_THAT(result_6_groups, IsSuccess());
1685+
EXPECT_THAT(result_1_group, IsSuccess());
1686+
1687+
// Calculate totals after applying DIVI data
1688+
double total_icu_6_groups_after = 0.0;
1689+
for (int i = 0; i < 6; i++) {
1690+
total_icu_6_groups_after +=
1691+
models_6_groups[0].populations[{mio::AgeGroup(i), mio::osecir::InfectionState::InfectedCritical}].value();
1692+
}
1693+
double icu_1_group_after =
1694+
models_1_group[0].populations[{mio::AgeGroup(0), mio::osecir::InfectionState::InfectedCritical}].value();
1695+
1696+
EXPECT_NEAR(total_icu_6_groups_after, icu_1_group_after, 1e-10);
1697+
}
1698+
15191699
#endif
15201700
#endif

docs/requirements.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
cmake>=3.26
2+
ninja
3+
scikit-build>=0.18
14
sphinx==7.1.2
25
sphinx-rtd-theme==1.3.0rc1
36
sphinx-copybutton

0 commit comments

Comments
 (0)