Skip to content

Commit febb82d

Browse files
authored
1019 Adjust memilio-generation for flow models and automatic differentiation (#1153)
Extend the generation with scanner now checks for Flowmodel and Compartmentalmodel, pymio includes fixed in template_string, config.json significantly reduced, added simulate and simulate_flows, added interpolate_simulation_result and interpolate_ensemble_results, and added scalartype in intermediate_representation to set a datatype for bindings.
1 parent 672cbc1 commit febb82d

18 files changed

Lines changed: 577 additions & 282 deletions

pycode/memilio-generation/README.md

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -29,20 +29,9 @@ pip install .
2929

3030
## Usage
3131

32-
The package provides an example script on how to use it in `memilio/tools`. The example uses the ode_seir model from the [C++ Library](../../cpp/models/ode_seir/README.md).
32+
The package provides an example script and README on how to use it in `memilio/tools`. The example uses the ode_seir model from the [C++ Library](../../cpp/models/ode_seir/README.md).
3333

34-
Before running the example you have to do these steps of setup:
35-
- Change [config.json.txt](./memilio/tools/config.json.txt). You can find a documentation in the ScannerConfig class.
36-
- Check if the parameters set in __post_init__() of the [ScannerConfig class](./memilio/generation/scanner_config.py) match with the cpp-class names.
37-
38-
Example:
39-
After processing as described in the previous paragraph, run the example with the command (path according to the current folder):
40-
41-
```bash
42-
python memilio/tools/example_oseir.py
43-
```
44-
45-
When working on a new model you can copy the example script and add an additional segment to the config.json.txt. The setup works similar to the example. Additionaly you can print the AST of your model into a file (Usefull for development/debugging).
34+
You can print the AST of your model into a file (Usefull for development/debugging).
4635

4736
## Testing
4837

@@ -55,11 +44,11 @@ python -m unittest
5544
## Development
5645

5746
When implementing new model features you can follow these steps:
58-
- Add necessary configurations to [config txt-file](./memilio/tools/config.json.txt) and add corresponding attributes to the [ScannerConfig class](./memilio/generation/scanner_config.py).
59-
- For the features you want to implement, find the nodes in the abstract syntax tree (AST) (use method Scanner.output_ast_file(); see the example in tools/).
47+
- For the features you want to implement, find the nodes in the abstract syntax tree (AST) (use method aviz.output_ast_formatted(); see the example in tools/).
6048
- Add the extraction of those features. Therefore you need to change the "check_..."-methods corresponding to the CursorKind of your nodes in the [Scanner class](./memilio/generation/scanner.py). If there is no corresponding "check_..."-method you need to write a new one and add it to the switch-method (scanner.switch_node_kind()).
6149
- Extend the [IntermediateRepresentation](./memilio/generation/intermediate_representation.py) for the new model features.
6250
- Adjust the [cpp-template](./memilio//generation/template/template_ode_cpp.txt) and the [string-template-methods](./memilio/generation/template/template_ode_string.py). If needed, use new identifiers and write new string-template-methods for them.
6351
- Adjust the substitution dictionaries in the [Generator class](./memilio/generation/generator.py).
6452
- Write new/Adjust script in the [tool folder](./memilio/tools/) for the model and try to run.
53+
- Add new strings in the [Default dict](/pycode/memilio-generation/memilio/generation/default_generation_dict.py)
6554
- Update [tests](./memilio/generation_test/).

pycode/memilio-generation/memilio/generation/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,4 @@
2626
from .scanner import Scanner
2727
from .scanner_config import ScannerConfig
2828
from .ast import AST
29+
from .default_generation_dict import default_dict

pycode/memilio-generation/memilio/generation/ast.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def _assing_ast_with_ids(self, cursor: Cursor) -> None:
113113
:param cursor: Cursor:
114114
115115
"""
116-
# assing_ids umschreiben -> mapping
116+
117117
self.cursor_id += 1
118118
id = self.cursor_id
119119
self.id_to_val[id] = cursor
@@ -137,7 +137,7 @@ def root_cursor(self):
137137
def get_node_id(self, cursor: Cursor) -> int:
138138
""" Returns the id of the current node.
139139
140-
Extracs the key from the current cursor from the dictonary id_to_val
140+
Extracts the key from the current cursor from the dictonary id_to_val
141141
142142
:param cursor: The current node of the AST as a cursor object from libclang.
143143
:param cursor: Cursor:
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
default_dict = {
2+
"model": "Model",
3+
"agegroup": "AgeGroup",
4+
"emptystring": "",
5+
"simulation": "Simulation",
6+
"flowmodel": "FlowModel",
7+
"compartmentalmodel": "CompartmentaModel",
8+
"modelfile": "model.h",
9+
"infectionstatefile": "infection_state.h",
10+
"parameterspacefile": "parameter_space.h",
11+
"analyzeresultfile": "analyze_result.h",
12+
"namespace": "namespace",
13+
"mio": "mio"
14+
}

pycode/memilio-generation/memilio/generation/generator.py

Lines changed: 23 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -51,50 +51,30 @@ def create_substitutions(
5151
:param self: Self:
5252
5353
"""
54-
# create the substitutions with a given intermed_repr
5554

56-
# substititutions for the py-file
57-
self.substitutions_py = {}
58-
self.substitutions_py["python_module_name"] = intermed_repr.python_module_name
59-
60-
# substititutions for the cpp-file
61-
self.substitutions_cpp = {}
62-
63-
self.substitutions_cpp["namespace"] = intermed_repr.namespace
64-
self.substitutions_cpp["model_class_name"] = intermed_repr.model_class
65-
self.substitutions_cpp["model_base"] = intermed_repr.model_base[0]
66-
self.substitutions_cpp["model_base_templates"] = intermed_repr.model_base[1][0] + \
67-
", " + intermed_repr.model_base[2][0] + \
68-
", " + intermed_repr.model_base[3][0]
69-
self.substitutions_cpp["python_module_name"] = intermed_repr.python_module_name
70-
self.substitutions_cpp["parameterset"] = intermed_repr.parameterset
71-
72-
self.substitutions_cpp["includes"] = StringTemplates.includes(
73-
intermed_repr)
74-
self.substitutions_cpp["pretty_name_function"] = StringTemplates.pretty_name_function(
75-
intermed_repr)
76-
self.substitutions_cpp["population_enums"] = StringTemplates.population_enums(
77-
intermed_repr)
78-
self.substitutions_cpp["model_init"] = StringTemplates.model_init(
79-
intermed_repr)
80-
self.substitutions_cpp["population"] = StringTemplates.population(
81-
intermed_repr)
82-
83-
# optional substitution strings for model with agegroups
84-
self.substitutions_cpp["parameterset_indexing"] = StringTemplates.parameterset_indexing(
85-
intermed_repr)
86-
self.substitutions_cpp["parameterset_wrapper"] = StringTemplates.parameterset_wrapper(
87-
intermed_repr)
88-
self.substitutions_cpp["age_group"] = StringTemplates.age_group(
89-
intermed_repr)
90-
91-
# optional substitution strings for model with simulation class
92-
self.substitutions_cpp["simulation"] = StringTemplates.simulation(
93-
intermed_repr)
94-
self.substitutions_cpp["simulation_graph"] = StringTemplates.simulation_graph(
95-
intermed_repr)
96-
self.substitutions_cpp["simulation_vector_definition"] = StringTemplates.simulation_vector_definition(
97-
intermed_repr)
55+
self.substitutions_py = {
56+
"python_module_name": intermed_repr.python_module_name
57+
}
58+
59+
self.substitutions_cpp = {
60+
"namespace": intermed_repr.namespace,
61+
"model_class_name": intermed_repr.model_class,
62+
"model_base": intermed_repr.model_base[0],
63+
"model_base_templates": intermed_repr.model_base_templates,
64+
"python_module_name": intermed_repr.python_module_name,
65+
"parameterset": intermed_repr.parameterset,
66+
"includes": StringTemplates.includes(intermed_repr),
67+
"pretty_name_function": StringTemplates.pretty_name_function(intermed_repr),
68+
"population_enums": StringTemplates.population_enums(intermed_repr),
69+
"model_init": StringTemplates.model_init(intermed_repr),
70+
"population": StringTemplates.population(intermed_repr),
71+
"parameterset_indexing": StringTemplates.parameterset_indexing(intermed_repr),
72+
"parameterset_wrapper": StringTemplates.parameterset_wrapper(intermed_repr),
73+
"simulation": StringTemplates.simulation(intermed_repr),
74+
"ScalarType": StringTemplates.ScalarType(intermed_repr),
75+
"draw_sample": StringTemplates.draw_sample(intermed_repr),
76+
"simulation_vector_definition": StringTemplates.simulation_vector_definition(intermed_repr)
77+
}
9878

9979
def generate_files(
10080
self: Self, intermed_repr: IntermediateRepresentation) -> None:
@@ -106,19 +86,16 @@ def generate_files(
10686
:param self: Self:
10787
10888
"""
109-
# read templates
11089
with open(os.path.join(intermed_repr.python_generation_module_path,
11190
"memilio/generation/template/template_py.txt")) as t:
11291
template_py = string.Template(t.read())
11392
with open(os.path.join(intermed_repr.python_generation_module_path,
11493
"memilio/generation/template/template_cpp.txt")) as t:
11594
template_cpp = string.Template(t.read())
11695

117-
# substitue identifiers
11896
output_py = template_py.safe_substitute(**self.substitutions_py)
11997
output_cpp = template_cpp.safe_substitute(**self.substitutions_cpp)
12098

121-
# print code into files
12299
py_filename = intermed_repr.python_module_name + ".py"
123100
cpp_filename = intermed_repr.python_module_name + ".cpp"
124101
with open(os.path.join(intermed_repr.target_folder, py_filename), "w") as output:

pycode/memilio-generation/memilio/generation/graph_visualization.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def file_writer(level: int, cursor_label: str) -> None:
104104
_output_cursor_and_children(cursor, ast, file_writer)
105105

106106
output_path = os.path.abspath(f"{output_file_name}")
107-
logging.info(f"AST-formated written to {output_path}")
107+
logging.info(f"AST-formatted written to {output_path}")
108108

109109

110110
def indent(level: int) -> str:
@@ -135,13 +135,14 @@ def _output_cursor_and_children(cursor: Cursor, ast: AST, writer: Callable[[int,
135135

136136
cursor_kind = f"<CursorKind.{cursor.kind.name}>"
137137
file_path = cursor.location.file.name if cursor.location.file else ""
138+
line_number = cursor.location.line if cursor.location.file else ""
138139

139140
if cursor.spelling:
140141
cursor_label = (f'ID:{cursor_id} {cursor.spelling} '
141142
f'{cursor_kind} '
142-
f'{file_path}')
143+
f'{file_path}:{line_number}')
143144
else:
144-
cursor_label = f'ID:{cursor_id} {cursor_kind} {file_path}'
145+
cursor_label = f'ID:{cursor_id} {cursor_kind} {file_path}:{line_number}'
145146

146147
writer(level, cursor_label)
147148

@@ -165,8 +166,7 @@ def _output_cursor_and_children_graphviz_digraph(cursor: Cursor, graph: Digraph,
165166
if current_d > max_d:
166167
return
167168

168-
node_label = f"{cursor.kind.name}{
169-
newline()}({cursor.spelling})" if cursor.spelling else cursor.kind.name
169+
node_label = f"{cursor.kind.name}{newline()}({cursor.spelling})" if cursor.spelling else cursor.kind.name
170170

171171
current_node = f"{cursor.kind.name}_{cursor.hash}"
172172

@@ -176,8 +176,7 @@ def _output_cursor_and_children_graphviz_digraph(cursor: Cursor, graph: Digraph,
176176
graph.edge(parent_node, current_node)
177177

178178
if cursor.kind.is_reference():
179-
referenced_label = f"ref_to_{cursor.referenced.kind.name}{
180-
newline()}({cursor.referenced.spelling})"
179+
referenced_label = f"ref_to_{cursor.referenced.kind.name}{newline()}({cursor.referenced.spelling})"
181180
referenced_node = f"ref_{cursor.referenced.hash}"
182181
graph.node(referenced_node, label=referenced_label)
183182
graph.edge(current_node, referenced_node)

pycode/memilio-generation/memilio/generation/intermediate_representation.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,22 +25,31 @@
2525
from typing import Any, Dict, Union
2626

2727
from typing_extensions import Self
28+
from memilio.generation import Generator
2829

2930

3031
@dataclass
3132
class IntermediateRepresentation:
32-
"""Dataclass storing the model features. Serves as interface between Scanner and Generator."""
33-
namespace: str = None
34-
model_class: str = None
35-
python_module_name: str = None
36-
parameterset: str = None
37-
parameterset_wrapper: str = None
38-
simulation_class: str = None
39-
python_generation_module_path: str = None
40-
target_folder: str = None
33+
"""
34+
Dataclass storing the model features. Serves as interface between Scanner and Generator.
35+
"""
36+
namespace: str = ""
37+
model_class: str = ""
38+
python_module_name: str = ""
39+
parameterset: str = ""
40+
parameterset_wrapper: str = ""
41+
simulation: bool = False
42+
is_compartmentalmodel: bool = False
43+
is_flowmodel: bool = False
44+
has_age_group: bool = False
45+
has_draw_sample: bool = False
46+
scalartype: str = "double"
47+
python_generation_module_path: str = ""
48+
target_folder: str = ""
4149
enum_populations: dict = field(default_factory=dict)
4250
model_init: list = field(default_factory=list)
4351
model_base: list = field(default_factory=list)
52+
model_base_templates: str = ""
4453
population_groups: list = field(default_factory=list)
4554
include_list: list = field(default_factory=list)
4655
age_group: dict = field(default_factory=dict)
@@ -57,6 +66,17 @@ def set_attribute(self: Self, attribute_name: str, value: Any) -> None:
5766
"""
5867
self.__setattr__(attribute_name, value)
5968

69+
def check_model_base(self: Self) -> None:
70+
"""
71+
Check if the model_base is set. If not, set it to the model_class.
72+
"""
73+
if len(self.model_base) > 0:
74+
self.model_base_templates = ", ".join(
75+
entry[0] for entry in self.model_base if len(entry) > 0
76+
)
77+
else:
78+
raise IndexError("model_base is empty. No base classes found.")
79+
6080
def check_complete_data(self: Self, optional: Dict
6181
[str, Union[str, bool]]) -> None:
6282
"""Check for missing data in the IntermediateRepresentation.

0 commit comments

Comments
 (0)