Commit ffe5e55

HenrZu, AgathaSchmidt, and mknaranja authored
396 First surrogate model with own structure (#456)
Implementation of some simple neural network surrogate models to learn the dynamics of simple equation-based models.

Co-authored-by: AgathaSchmidt <[email protected]>
Co-authored-by: Martin J. Kühn <[email protected]>
1 parent 363c39d commit ffe5e55

15 files changed: +918 -0 lines changed

.github/actions/test-py/action.yml

Lines changed: 6 additions & 0 deletions
@@ -22,6 +22,12 @@ runs:
       with:
         name: python-wheels-${{ inputs.package }}
         path: pycode/wheelhouse
+    - name: Download Python Wheels for surrogatemodel
+      uses: actions/download-artifact@v3
+      if: inputs.package == 'surrogatemodel'
+      with:
+        name: python-wheels-simulation
+        path: pycode/wheelhouse
     - name: Set up Python 3.8
       uses: actions/setup-python@v4
       with:

.github/workflows/main.yml

Lines changed: 21 additions & 0 deletions
@@ -191,6 +191,27 @@ jobs:
         with:
           package: simulation
 
+  build-py-surrogatemodel:
+    if: github.event.pull_request.draft == false
+    runs-on: ubuntu-latest
+    container:
+      image: quay.io/pypa/manylinux2014_x86_64
+    steps:
+      - uses: actions/checkout@v2
+      - uses: ./.github/actions/build-py
+        with:
+          package: surrogatemodel
+
+  test-py-surrogatemodel:
+    if: github.event.pull_request.draft == false
+    needs: [build-py-surrogatemodel, build-py-simulation]
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v2
+      - uses: ./.github/actions/test-py
+        with:
+          package: surrogatemodel
+
   test-docs:
     if: github.event.pull_request.draft == false
     runs-on: ubuntu-latest

.github/workflows/minimal.yml

Lines changed: 19 additions & 0 deletions
@@ -87,6 +87,25 @@ jobs:
         with:
           package: simulation
 
+  build-py-surrogatemodel:
+    runs-on: ubuntu-latest
+    container:
+      image: quay.io/pypa/manylinux2014_x86_64
+    steps:
+      - uses: actions/checkout@v2
+      - uses: ./.github/actions/build-py
+        with:
+          package: surrogatemodel
+
+  test-py-surrogatemodel:
+    needs: [build-py-surrogatemodel, build-py-simulation]
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v2
+      - uses: ./.github/actions/test-py
+        with:
+          package: surrogatemodel
+
   test-pylint-epidata:
     needs: build-py-epidata
     runs-on: ubuntu-latest
Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
+MEmilio Surrogate Model Package
+=======================
+This package contains machine-learning-based surrogate models that make predictions based on the MEmilio simulation models. Currently there are only surrogate models for ODE models. Simulations of these equation-based models are used for data generation. The goal is to create a powerful tool that predicts the dynamics faster than a simulation of an expert model, e.g., a metapopulation or agent-based model, while still having acceptable errors with respect to the original simulations.
+
+## Installation
+
+Use the provided `setup.py` script to install the package.
+To do so, run the following command from the directory containing `setup.py`:
+
+```bash
+pip install .
+```
+
+For development of the code, use
+
+```bash
+pip install -e .[dev]
+```
+
+Since we are running simulations to generate the data, the MEmilio `memilio-simulation` package (https://github.com/DLR-SC/memilio/tree/main/pycode/memilio-simulation) also needs to be installed.
+## Usage
+The package currently provides the following modules:
+
+- `models`: models for different specific tasks
+    Currently we have the following models:
+    - `ode_secir_simple`: A simple model allowing for asymptomatic as well as symptomatic infection, not stratified by age groups.
+
+    Each model folder contains the following files:
+    - `data_generation`: generation of data from expert model simulations.
+    - `model`: training and evaluation of the model.
+    - `network_architectures`: multiple network architectures are saved in this file.
+
+
+- `tests`: this file contains all tests
+
+## Testing
+The package provides a test suite in `memilio/surrogatemodel_test`. To run the tests, run the following command:
+
+```bash
+python -m unittest
+```
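The dataset written by `data_generation` is a pickled dictionary with the keys `inputs` and `labels`, stored as `data_secir_simple.pickle`. A minimal sketch of writing and reading back a file with that layout is given here. It rests on two assumptions that are not part of the package: the stand-in dictionary holds NumPy arrays instead of the TensorFlow tensors the package stores, and a temporary directory replaces the package's `data` folder.

```python
import os
import pickle
import tempfile

import numpy as np

# Stand-in dataset (assumption): 4 runs, input_width = 5, label_width = 30,
# 8 compartments. The real file holds TensorFlow tensors instead.
data = {
    "inputs": np.zeros((4, 5, 8)),
    "labels": np.zeros((4, 30, 8)),
}

with tempfile.TemporaryDirectory() as path:
    file_path = os.path.join(path, "data_secir_simple.pickle")

    # Write the dictionary the same way generate_data does.
    with open(file_path, "wb") as f:
        pickle.dump(data, f)

    # Read the dataset back, e.g. before training a model on it.
    with open(file_path, "rb") as f:
        loaded = pickle.load(f)

print(loaded["inputs"].shape, loaded["labels"].shape)
```

Reading a file produced by the package itself would presumably also require TensorFlow to be installed, since the stored values are tensors.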
Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
+#############################################################################
+# Copyright (C) 2020-2023 German Aerospace Center (DLR-SC)
+#
+# Authors: Daniel Abele
+#
+# Contact: Martin J. Kuehn <[email protected]>
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#############################################################################
+
+"""
+MEmilio main namespace package.
+"""
+
+__path__ = __import__('pkgutil').extend_path(__path__, __name__)
Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+#############################################################################
+# Copyright (C) 2020-2023 German Aerospace Center (DLR-SC)
+#
+# Authors: Agatha Schmidt, Henrik Zunker
+#
+# Contact: Martin J. Kuehn <[email protected]>
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#############################################################################
+
+"""
+Machine-learnt surrogate models for equation- or agent-based expert models.
+"""
Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+# Simple SECIR model
+
+This model is a very simplified application of the SECIR model implemented in https://github.com/DLR-SC/memilio/tree/main/cpp/models/ode_secir/, not stratified by age groups.
+The example is based on https://github.com/DLR-SC/memilio/tree/main/pycode/examples/simulation/secir_simple.py, which uses Python bindings to run the underlying C++ code.
Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+#############################################################################
+# Copyright (C) 2020-2023 German Aerospace Center (DLR-SC)
+#
+# Authors: Agatha Schmidt, Henrik Zunker
+#
+# Contact: Martin J. Kuehn <[email protected]>
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#############################################################################
+
+"""
+A surrogate model for a simple SECIR model allowing for asymptomatic as well as symptomatic infection not stratified by age groups.
+"""
Lines changed: 197 additions & 0 deletions
@@ -0,0 +1,197 @@
+#############################################################################
+# Copyright (C) 2020-2023 German Aerospace Center (DLR-SC)
+#
+# Authors: Agatha Schmidt, Henrik Zunker
+#
+# Contact: Martin J. Kuehn <[email protected]>
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#############################################################################
+import copy
+import os
+import pickle
+import random
+from datetime import date
+
+import numpy as np
+import pandas as pd
+import tensorflow as tf
+from progress.bar import Bar
+from sklearn.preprocessing import FunctionTransformer
+
+from memilio.simulation import (ContactMatrix, Damping, LogLevel,
+                                UncertainContactMatrix, set_log_level)
+from memilio.simulation.secir import (AgeGroup, Index_InfectionState,
+                                      InfectionState, Model, Simulation,
+                                      interpolate_simulation_result, simulate)
+
+
+def run_secir_simulation(days):
+    """! Uses an ODE SECIR model allowing for asymptomatic infection. The model is not stratified by region or demographic properties such as age.
+    Virus-specific parameters are fixed and the initial numbers of persons in the particular infection states are chosen randomly from defined ranges.
+
+    @param days Describes how many days we simulate within a single run.
+    @return List containing the populations in each compartment for each day of the simulation.
+    """
+    set_log_level(LogLevel.Off)
+
+    populations = [50_000]
+    start_day = 1
+    start_month = 1
+    start_year = 2019
+    dt = 0.1
+    num_groups = 1
+
+    # Initialize Parameters
+    model = Model(1)
+
+    A0 = AgeGroup(0)
+
+    # Set parameters
+    # Compartment transition duration
+    model.parameters.IncubationTime[A0] = 5.2
+    model.parameters.TimeInfectedSymptoms[A0] = 6.
+    model.parameters.SerialInterval[A0] = 4.2
+    model.parameters.TimeInfectedSevere[A0] = 12.
+    model.parameters.TimeInfectedCritical[A0] = 8.
+
+    # Initial number of people in each compartment with random numbers
+    model.populations[A0, Index_InfectionState(
+        InfectionState.Exposed)] = 60 * random.uniform(0.2, 1)
+    model.populations[A0, Index_InfectionState(
+        InfectionState.InfectedNoSymptoms)] = 55 * random.uniform(0.2, 1)
+    model.populations[A0, Index_InfectionState(
+        InfectionState.InfectedSymptoms)] = 50 * random.uniform(0.2, 1)
+    model.populations[A0, Index_InfectionState(
+        InfectionState.InfectedSevere)] = 12 * random.uniform(0.2, 1)
+    model.populations[A0, Index_InfectionState(
+        InfectionState.InfectedCritical)] = 3 * random.uniform(0.2, 1)
+    model.populations[A0, Index_InfectionState(
+        InfectionState.Recovered)] = 50 * random.random()
+    model.populations[A0, Index_InfectionState(InfectionState.Dead)] = 0
+    model.populations.set_difference_from_total(
+        (A0, Index_InfectionState(InfectionState.Susceptible)), populations[0])
+
+    # Compartment transition probabilities
+    model.parameters.RelativeTransmissionNoSymptoms[A0] = 0.5
+    model.parameters.TransmissionProbabilityOnContact[A0] = 0.1
+    model.parameters.RecoveredPerInfectedNoSymptoms[A0] = 0.09
+    model.parameters.RiskOfInfectionFromSymptomatic[A0] = 0.25
+    model.parameters.SeverePerInfectedSymptoms[A0] = 0.2
+    model.parameters.CriticalPerSevere[A0] = 0.25
+    model.parameters.DeathsPerCritical[A0] = 0.3
+    # twice the value of RiskOfInfectionFromSymptomatic
+    model.parameters.MaxRiskOfInfectionFromSymptomatic[A0] = 0.5
+
+    model.parameters.StartDay = (
+        date(start_year, start_month, start_day) - date(start_year, 1, 1)).days
+
+    model.parameters.ContactPatterns.cont_freq_mat[0].baseline = np.ones(
+        (num_groups, num_groups)) * 10
+    model.parameters.ContactPatterns.cont_freq_mat[0].minimum = np.ones(
+        (num_groups, num_groups)) * 0
+
+    # Apply mathematical constraints to parameters
+    model.apply_constraints()
+
+    # Run Simulation
+    result = simulate(0, days, dt, model)
+    # Interpolate simulation result on days time scale
+    result = interpolate_simulation_result(result)
+
+    # Using an array instead of a list to avoid problems with references
+    result_array = result.as_ndarray()
+    dataset = []
+    # Omit first row, as the time points are not of interest here.
+    dataset_entries = copy.deepcopy(result_array[1:, :].transpose())
+
+    return dataset_entries.tolist()
+
+
+def generate_data(
+        num_runs, path, input_width, label_width, normalize=True,
+        save_data=True):
+    """! Generates data sets from num_runs equation-based model simulations and transforms the computed results by a log(1+x) transformation.
+    Divides the results into input and label data sets and returns them as a dictionary of two stacked TensorFlow tensors.
+
+    In general, we have 8 different compartments. If we choose
+    input_width = 5 and label_width = 20, the dataset has
+    - input with dimension 5 x 8
+    - labels with dimension 20 x 8
+
+    @param num_runs Number of times the function run_secir_simulation is called.
+    @param path Path where the dataset is saved to.
+    @param input_width Int value that defines the number of time steps (days) used for the input.
+    @param label_width Int value that defines the number of time steps (days) used for the labels.
+    @param normalize [Default: True] Option to transform dataset by logarithmic normalization.
+    @param save_data [Default: True] Option to save the dataset.
+    @return Data dictionary of input and label data sets.
+    """
+    data = {
+        "inputs": [],
+        "labels": []
+    }
+
+    # The number of days is the same as the sum of input and label width.
+    # Since the first day of the input is day 0, we still need to subtract 1.
+    days = input_width + label_width - 1
+
+    # Show progress in terminal for longer runs
+    # Due to the random structure, there is currently no need to shuffle the data
+    bar = Bar('Number of Runs done', max=num_runs)
+    for _ in range(0, num_runs):
+        data_run = run_secir_simulation(days)
+        data['inputs'].append(data_run[:input_width])
+        data['labels'].append(data_run[input_width:])
+        bar.next()
+    bar.finish()
+
+    if normalize:
+        # logarithmic normalization
+        transformer = FunctionTransformer(np.log1p, validate=True)
+        inputs = np.asarray(data['inputs']).transpose(2, 0, 1).reshape(8, -1)
+        scaled_inputs = transformer.transform(inputs)
+        scaled_inputs = scaled_inputs.transpose().reshape(num_runs, input_width, 8)
+        scaled_inputs_list = scaled_inputs.tolist()
+
+        labels = np.asarray(data['labels']).transpose(2, 0, 1).reshape(8, -1)
+        scaled_labels = transformer.transform(labels)
+        scaled_labels = scaled_labels.transpose().reshape(num_runs, label_width, 8)
+        scaled_labels_list = scaled_labels.tolist()
+
+        # cast lists to tensors
+        data['inputs'] = tf.stack(scaled_inputs_list)
+        data['labels'] = tf.stack(scaled_labels_list)
+
+    if save_data:
+        # check if data directory exists. If necessary, create it.
+        if not os.path.isdir(path):
+            os.mkdir(path)
+
+        # save dict to pickle file
+        with open(os.path.join(path, 'data_secir_simple.pickle'), 'wb') as f:
+            pickle.dump(data, f)
+    return data
+
+
+if __name__ == "__main__":
+    # Store data relative to current file two levels higher.
+    path = os.path.dirname(os.path.realpath(__file__))
+    path_data = os.path.join(os.path.dirname(os.path.realpath(
+        os.path.dirname(os.path.realpath(path)))), 'data')
+
+    input_width = 5
+    label_width = 30
+    num_runs = 1000
+    data = generate_data(num_runs, path_data, input_width,
+                         label_width)
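
The window split and the log(1+x) scaling in `generate_data` can be tried without a MEmilio installation. The following is a minimal sketch under stated assumptions: `split_run` and `log1p_normalize` are hypothetical helper names, random numbers stand in for the output of `run_secir_simulation`, and `np.log1p` is called directly instead of through scikit-learn's `FunctionTransformer`. It illustrates that a run of `input_width + label_width - 1` days yields `input_width + label_width` daily rows, and that the transpose/reshape round trip restores the original `(runs, days, compartments)` layout.

```python
import numpy as np


def split_run(run, input_width):
    """Split one simulated run into its input and label windows."""
    return run[:input_width], run[input_width:]


def log1p_normalize(windows, num_compartments=8):
    """Apply log(1 + x) with the same reshape round trip as generate_data."""
    arr = np.asarray(windows)  # shape: (num_runs, width, num_compartments)
    num_runs, width, _ = arr.shape
    # Flatten to one row per compartment, transform, then restore the shape.
    flat = arr.transpose(2, 0, 1).reshape(num_compartments, -1)
    scaled = np.log1p(flat)
    return scaled.transpose().reshape(num_runs, width, num_compartments)


input_width, label_width, num_runs = 5, 20, 3
days = input_width + label_width - 1  # day 0 counts as the first input day

# Synthetic stand-in (assumption) for run_secir_simulation(days):
# days + 1 daily rows with 8 compartment values each.
rng = np.random.default_rng(seed=0)
runs = rng.uniform(0.0, 100.0, size=(num_runs, days + 1, 8))

inputs, labels = [], []
for run in runs:
    run_inputs, run_labels = split_run(run, input_width)
    inputs.append(run_inputs)
    labels.append(run_labels)

scaled_inputs = log1p_normalize(inputs)
scaled_labels = log1p_normalize(labels)
print(scaled_inputs.shape, scaled_labels.shape)
```

Because `np.log1p` acts element-wise, the reshape only changes the layout handed to the transformer and does not change the values. Applying `np.expm1` to the scaled values recovers compartment sizes on the original scale.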
