Skip to content

Commit 1aafc48

Browse files
mhheger and HenrZu authored
1293 update secir groups surrogate models (#1297)
- rework the surrogate models for the age resolved ode secir model - New utils folder to avoid doubling code for basic surrogate functionalities - New handling of dampings in the surrogate model Co-authored-by: Henrik Zunker <[email protected]>
1 parent 3a0d789 commit 1aafc48

14 files changed

Lines changed: 1047 additions & 426 deletions

File tree

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
# SECIR model with multiple age groups and one damping
1+
# SECIR model with multiple age groups and multiple dampings
22

3-
This model is an application of the SECIR model implemented in https://github.com/DLR-SC/memilio/tree/main/cpp/models/ode_secir/ stratified by age groups using one damping to represent a change in the contact matrice.
4-
The example is based on https://github.com/DLR-SC/memilio/tree/main/pycode/examples/simulation/secir_groups.py which uses python bindings to run the underlying C++ code.
3+
This model is an application of the SECIR model implemented in https://github.com/DLR-SC/memilio/tree/main/cpp/models/ode_secir/ stratified by age groups using dampings to represent changes in the contact matrix.
4+
The example is based on https://github.com/DLR-SC/memilio/tree/main/pycode/examples/simulation/ode_secir_groups.py which uses python bindings to run the underlying C++ code.

pycode/memilio-surrogatemodel/memilio/surrogatemodel/ode_secir_groups/data_generation.py

Lines changed: 86 additions & 93 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@
1717
# See the License for the specific language governing permissions and
1818
# limitations under the License.
1919
#############################################################################
20+
2021
import copy
2122
import os
2223
import pickle
@@ -33,75 +34,41 @@
3334
from memilio.simulation.osecir import (Index_InfectionState,
3435
InfectionState, Model,
3536
interpolate_simulation_result, simulate)
37+
import memilio.surrogatemodel.utils.dampings as dampings
38+
from memilio.surrogatemodel.utils.helper_functions import (
39+
interpolate_age_groups, remove_confirmed_compartments, normalize_simulation_data)
40+
import memilio.simulation as mio
41+
import memilio.simulation.osecir as osecir
3642

3743

38-
def interpolate_age_groups(data_entry):
39-
""" Interpolates the age groups from the population data into the age groups used in the simulation.
40-
We assume that the people in the age groups are uniformly distributed.
41-
42-
:param data_entry: Data entry containing the population data.
43-
:returns: List containing the population in each age group used in the simulation.
44-
45-
"""
46-
age_groups = {
47-
"A00-A04": data_entry['<3 years'] + data_entry['3-5 years'] * 2 / 3,
48-
"A05-A14": data_entry['3-5 years'] * 1 / 3 + data_entry['6-14 years'],
49-
"A15-A34": data_entry['15-17 years'] + data_entry['18-24 years'] + data_entry['25-29 years'] + data_entry['30-39 years'] * 1 / 2,
50-
"A35-A59": data_entry['30-39 years'] * 1 / 2 + data_entry['40-49 years'] + data_entry['50-64 years'] * 2 / 3,
51-
"A60-A79": data_entry['50-64 years'] * 1 / 3 + data_entry['65-74 years'] + data_entry['>74 years'] * 1 / 5,
52-
"A80+": data_entry['>74 years'] * 4 / 5
53-
}
54-
return [age_groups[key] for key in age_groups]
55-
56-
57-
def remove_confirmed_compartments(result_array):
58-
""" Removes the confirmed compartments which are not used in the data generation.
59-
60-
:param result_array: Array containing the simulation results.
61-
:returns: Array containing the simulation results without the confirmed compartments.
62-
63-
"""
64-
num_groups = int(result_array.shape[1] / 10)
65-
delete_indices = [index for i in range(
66-
num_groups) for index in (3+10*i, 5+10*i)]
67-
return np.delete(result_array, delete_indices, axis=1)
68-
69-
70-
def transform_data(data, transformer, num_runs):
71-
""" Transforms the data by a logarithmic normalization.
72-
Reshaping is necessary, because the transformer needs an array with dimension <= 2.
73-
74-
:param data: Data to be transformed.
75-
:param transformer: Transformer used for the transformation.
76-
:param num_runs:
77-
:returns: Transformed data.
78-
79-
"""
80-
data = np.asarray(data).transpose(2, 0, 1).reshape(48, -1)
81-
scaled_data = transformer.transform(data)
82-
return tf.convert_to_tensor(scaled_data.transpose().reshape(num_runs, -1, 48))
83-
84-
85-
def run_secir_groups_simulation(days, damping_day, populations):
44+
def run_secir_groups_simulation(days, damping_days, damping_factors, populations):
8645
""" Uses an ODE SECIR model allowing for asymptomatic infection with 6 different age groups. The model is not stratified by region.
8746
Virus-specific parameters are fixed and initial number of persons in the particular infection states are chosen randomly from defined ranges.
8847
8948
:param days: Describes how many days we simulate within a single run.
90-
:param damping_day: The day when damping is applied.
49+
:param damping_days: The days when damping is applied.
50+
:param damping_factors: damping factors associated to the damping days.
9151
:param populations: List containing the population in each age group.
92-
:returns: List containing the populations in each compartment used to initialize the run.
52+
:returns: Tuple of lists (list_of_simulation_results, list_of_damped_matrices), the first containing the simulation results, the second list containing the
53+
damped contact matrices.
9354
9455
"""
56+
# Collect indices of confirmed compartments
57+
del_indices = []
58+
59+
if len(damping_days) != len(damping_factors):
60+
raise ValueError("Length of damping_days and damping_factors differ!")
61+
9562
set_log_level(LogLevel.Off)
9663

9764
start_day = 1
9865
start_month = 1
9966
start_year = 2019
10067
dt = 0.1
10168

102-
# Define age Groups
103-
groups = ['0-4', '5-14', '15-34', '35-59', '60-79', '80+']
104-
num_groups = len(groups)
69+
age_groups = ['0-4', '5-14', '15-34', '35-59', '60-79', '80+']
70+
# Get number of age groups
71+
num_groups = len(age_groups)
10572

10673
# Initialize Parameters
10774
model = Model(num_groups)
@@ -118,26 +85,26 @@ def run_secir_groups_simulation(days, damping_day, populations):
11885
# Initial number of people in each compartment with random numbers
11986
model.populations[AgeGroup(i), Index_InfectionState(
12087
InfectionState.Exposed)] = random.uniform(
121-
0.00025, 0.0005) * populations[i]
88+
0.00025, 0.005) * populations[i]
12289
model.populations[AgeGroup(i), Index_InfectionState(
12390
InfectionState.InfectedNoSymptoms)] = random.uniform(
124-
0.0001, 0.00035) * populations[i]
91+
0.0001, 0.0035) * populations[i]
12592
model.populations[AgeGroup(i), Index_InfectionState(
12693
InfectionState.InfectedNoSymptomsConfirmed)] = 0
12794
model.populations[AgeGroup(i), Index_InfectionState(
12895
InfectionState.InfectedSymptoms)] = random.uniform(
129-
0.00007, 0.0001) * populations[i]
96+
0.00007, 0.001) * populations[i]
13097
model.populations[AgeGroup(i), Index_InfectionState(
13198
InfectionState.InfectedSymptomsConfirmed)] = 0
13299
model.populations[AgeGroup(i), Index_InfectionState(
133100
InfectionState.InfectedSevere)] = random.uniform(
134-
0.00003, 0.00006) * populations[i]
101+
0.00003, 0.0006) * populations[i]
135102
model.populations[AgeGroup(i), Index_InfectionState(
136103
InfectionState.InfectedCritical)] = random.uniform(
137-
0.00001, 0.00002) * populations[i]
104+
0.00001, 0.0002) * populations[i]
138105
model.populations[AgeGroup(i), Index_InfectionState(
139106
InfectionState.Recovered)] = random.uniform(
140-
0.002, 0.008) * populations[i]
107+
0.002, 0.08) * populations[i]
141108
model.populations[AgeGroup(i),
142109
Index_InfectionState(InfectionState.Dead)] = 0
143110
model.populations.set_difference_from_group_total_AgeGroup(
@@ -154,6 +121,14 @@ def run_secir_groups_simulation(days, damping_day, populations):
154121
# twice the value of RiskOfInfectionFromSymptomatic
155122
model.parameters.MaxRiskOfInfectionFromSymptomatic[AgeGroup(i)] = 0.5
156123

124+
# Collecting deletable indices
125+
index_no_sym_conf = model.populations.get_flat_index(
126+
osecir.MultiIndex_PopulationsArray(mio.AgeGroup(i), osecir.InfectionState.InfectedNoSymptomsConfirmed))
127+
index_sym_conf = model.populations.get_flat_index(
128+
osecir.MultiIndex_PopulationsArray(mio.AgeGroup(i), osecir.InfectionState.InfectedSymptomsConfirmed))
129+
del_indices.append(index_no_sym_conf)
130+
del_indices.append(index_sym_conf)
131+
157132
# StartDay is the n-th day of the year
158133
model.parameters.StartDay = (
159134
date(start_year, start_month, start_day) - date(start_year, 1, 1)).days
@@ -166,14 +141,16 @@ def run_secir_groups_simulation(days, damping_day, populations):
166141
model.parameters.ContactPatterns.cont_freq_mat[0].minimum = minimum
167142

168143
# Generate a damping matrix and assign it to the model
169-
damping = np.ones((num_groups, num_groups)
170-
) * np.float16(random.uniform(0, 0.5))
171-
172-
model.parameters.ContactPatterns.cont_freq_mat.add_damping(Damping(
173-
coeffs=(damping), t=damping_day, level=0, type=0))
144+
damped_matrices = []
174145

175-
damped_contact_matrix = model.parameters.ContactPatterns.cont_freq_mat.get_matrix_at(
176-
damping_day+1)
146+
for i in np.arange(len(damping_days)):
147+
damping = np.ones((num_groups, num_groups)
148+
) * damping_factors[i]
149+
day = damping_days[i]
150+
model.parameters.ContactPatterns.cont_freq_mat.add_damping(Damping(
151+
coeffs=(damping), t=day, level=0, type=0))
152+
damped_matrices.append(model.parameters.ContactPatterns.cont_freq_mat.get_matrix_at(
153+
day+1))
177154

178155
# Apply mathematical constraints to parameters
179156
model.apply_constraints()
@@ -184,21 +161,22 @@ def run_secir_groups_simulation(days, damping_day, populations):
184161
# Interpolate simulation result on days time scale
185162
result = interpolate_simulation_result(result)
186163

164+
# Omit first column, as the time points are not of interest here.
187165
result_array = remove_confirmed_compartments(
188-
np.transpose(result.as_ndarray()[1:, :]))
166+
np.transpose(result.as_ndarray()[1:, :]), del_indices)
189167

190-
# Omit first column, as the time points are not of interest here.
191168
dataset_entries = copy.deepcopy(result_array)
169+
dataset_entries = np.nan_to_num(dataset_entries)
192170

193-
return dataset_entries.tolist(), damped_contact_matrix
171+
return dataset_entries.tolist(), damped_matrices
194172

195173

196174
def generate_data(
197175
num_runs, path_out, path_population, input_width, label_width,
198-
normalize=True, save_data=True):
199-
""" Generate data sets of num_runs many equation-based model simulations and transforms the computed results by a log(1+x) transformation.
176+
normalize=True, save_data=True, damping_method="random", number_dampings=5):
177+
""" Generate data sets of num_runs many equation-based model simulations and possibly transforms the computed results by a log(1+x) transformation.
200178
Divides the results in input and label data sets and returns them as a dictionary of two TensorFlow Stacks.
201-
In general, we have 8 different compartments and 6 age groups. If we choose,
179+
In general, we have 8 different compartments and 6 age groups. If we choose
202180
input_width = 5 and label_width = 20, the dataset has
203181
- input with dimension 5 x 8 x 6
204182
- labels with dimension 20 x 8 x 6
@@ -210,14 +188,17 @@ def generate_data(
210188
:param label_width: Int value that defines the size of the labels.
211189
:param normalize: Default: true Option to transform dataset by logarithmic normalization.
212190
:param save_data: Default: true Option to save the dataset.
191+
:param damping_method: String specifying the damping method, that should be used. Possible values "classic", "active", "random".
192+
:param number_dampings: Maximal number of possible dampings.
213193
:returns: Data dictionary of input and label data sets.
214194
215195
"""
216196
data = {
217197
"inputs": [],
218198
"labels": [],
219-
"contact_matrix": [],
220-
"damping_day": []
199+
"contact_matrices": [],
200+
"damping_factors": [],
201+
"damping_days": []
221202
}
222203

223204
# The number of days is the same as the sum of input and label width.
@@ -232,16 +213,18 @@ def generate_data(
232213
bar = Bar('Number of Runs done', max=num_runs)
233214
for _ in range(0, num_runs):
234215

235-
# Generate a random damping day
236-
damping_day = random.randrange(
237-
input_width, input_width+label_width)
216+
# Generate random damping days
217+
damping_days, damping_factors = dampings.generate_dampings(
218+
days, number_dampings, method=damping_method, min_distance=2,
219+
min_damping_day=2)
238220

239-
data_run, damped_contact_matrix = run_secir_groups_simulation(
240-
days, damping_day, population[random.randint(0, len(population) - 1)])
221+
data_run, damped_matrices = run_secir_groups_simulation(
222+
days, damping_days, damping_factors, population[random.randint(0, len(population) - 1)])
241223
data['inputs'].append(data_run[:input_width])
242224
data['labels'].append(data_run[input_width:])
243-
data['contact_matrix'].append(np.array(damped_contact_matrix))
244-
data['damping_day'].append([damping_day])
225+
data['contact_matrices'].append(damped_matrices)
226+
data['damping_factors'].append(damping_factors)
227+
data['damping_days'].append(damping_days)
245228
bar.next()
246229
bar.finish()
247230

@@ -250,8 +233,10 @@ def generate_data(
250233
transformer = FunctionTransformer(np.log1p, validate=True)
251234

252235
# transform inputs and labels
253-
data['inputs'] = transform_data(data['inputs'], transformer, num_runs)
254-
data['labels'] = transform_data(data['labels'], transformer, num_runs)
236+
data['inputs'] = normalize_simulation_data(
237+
data['inputs'], transformer, num_runs)
238+
data['labels'] = normalize_simulation_data(
239+
data['labels'], transformer, num_runs)
255240
else:
256241
data['inputs'] = tf.convert_to_tensor(data['inputs'])
257242
data['labels'] = tf.convert_to_tensor(data['labels'])
@@ -261,8 +246,15 @@ def generate_data(
261246
if not os.path.isdir(path_out):
262247
os.mkdir(path_out)
263248

264-
# save dict to json file
265-
with open(os.path.join(path_out, 'data_secir_groups.pickle'), 'wb') as f:
249+
# save dict to pickle file
250+
if num_runs < 1000:
251+
filename = 'data_secir_groups_%ddays_%d_' % (
252+
label_width, num_runs) + damping_method+'.pickle'
253+
else:
254+
filename = 'data_secir_groups_%ddays_%dk_' % (
255+
label_width, num_runs//1000) + damping_method+'.pickle'
256+
257+
with open(os.path.join(path_out, filename), 'wb') as f:
266258
pickle.dump(data, f)
267259
return data
268260

@@ -291,13 +283,13 @@ def getMinimumMatrix():
291283
""" loads the minimum matrix"""
292284

293285
minimum_contact_matrix0 = os.path.join(
294-
"./data/contacts/minimum_home.txt")
286+
"./data/Germany/contacts/minimum_home.txt")
295287
minimum_contact_matrix1 = os.path.join(
296-
"./data/contacts/minimum_school_pf_eig.txt")
288+
"./data/Germany/contacts/minimum_school_pf_eig.txt")
297289
minimum_contact_matrix2 = os.path.join(
298-
"./data/contacts/minimum_work.txt")
290+
"./data/Germany/contacts/minimum_work.txt")
299291
minimum_contact_matrix3 = os.path.join(
300-
"./data/contacts/minimum_other.txt")
292+
"./data/Germany/contacts/minimum_other.txt")
301293

302294
minimum = np.loadtxt(minimum_contact_matrix0) \
303295
+ np.loadtxt(minimum_contact_matrix1) + \
@@ -310,7 +302,8 @@ def getMinimumMatrix():
310302
def get_population(path):
311303
""" read population data in list from dataset
312304
313-
:param path: Path to the dataset containing the population data
305+
:param path: Path to the dataset containing the population
306+
:returns: List of interpolated age grouped population data
314307
315308
"""
316309

@@ -332,7 +325,7 @@ def get_population(path):
332325
r"data//Germany//pydata//county_current_population.json")
333326

334327
input_width = 5
335-
label_width = 30
336-
num_runs = 10000
328+
label_width = 90
329+
num_runs = 100
337330
data = generate_data(num_runs, path_output, path_population, input_width,
338-
label_width)
331+
label_width, damping_method="active")

0 commit comments

Comments (0)