Skip to content

Commit 9e825cf

Browse files
1486 Modernizations and minor fixes in core C++ library (#1487)
- Replace is_expression_valid clauses using the requires keyword. - Replace several SFINAE constructions with concepts or if constexpr constructions. - Update some includes and include guards. - Refactor mio::Range to adhere more closely to std::ranges and to simplify usage. Co-authored-by: annawendler <[email protected]>
1 parent 12db98b commit 9e825cf

90 files changed

Lines changed: 970 additions & 1703 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

cpp/benchmarks/secir_ageres_setups.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ namespace mio
3030
{
3131
namespace benchmark
3232
{
33-
namespace detail
33+
namespace details
3434
{
3535
/**
3636
* @brief Helper function to create a secir model with consistent setup for use in benchmarking.
@@ -89,7 +89,7 @@ mio::osecir::Model<ScalarType> make_model(int num)
8989

9090
return model;
9191
}
92-
} // namespace detail
92+
} // namespace details
9393

9494
namespace model
9595
{
@@ -98,7 +98,7 @@ namespace model
9898
*/
9999
mio::osecir::Model<ScalarType> SecirAgeres(size_t num_agegroups)
100100
{
101-
mio::osecir::Model<ScalarType> model = mio::benchmark::detail::make_model(num_agegroups);
101+
mio::osecir::Model<ScalarType> model = mio::benchmark::details::make_model(num_agegroups);
102102

103103
auto nb_groups = model.parameters.get_num_groups();
104104
ScalarType cont_freq = 10, fact = 1.0 / (ScalarType)(size_t)nb_groups;
@@ -117,7 +117,7 @@ mio::osecir::Model<ScalarType> SecirAgeres(size_t num_agegroups)
117117
*/
118118
mio::osecir::Model<ScalarType> SecirAgeresDampings(size_t num_agegroups)
119119
{
120-
mio::osecir::Model<ScalarType> model = mio::benchmark::detail::make_model(num_agegroups);
120+
mio::osecir::Model<ScalarType> model = mio::benchmark::details::make_model(num_agegroups);
121121

122122
auto nb_groups = model.parameters.get_num_groups();
123123
ScalarType cont_freq = 10, fact = 1.0 / (ScalarType)(size_t)nb_groups;
@@ -145,7 +145,7 @@ mio::osecir::Model<ScalarType> SecirAgeresDampings(size_t num_agegroups)
145145
*/
146146
mio::osecir::Model<ScalarType> SecirAgeresAbsurdDampings(size_t num_agegroups)
147147
{
148-
mio::osecir::Model<ScalarType> model = mio::benchmark::detail::make_model(num_agegroups);
148+
mio::osecir::Model<ScalarType> model = mio::benchmark::details::make_model(num_agegroups);
149149

150150
auto nb_groups = model.parameters.get_num_groups();
151151
ScalarType cont_freq = 10, fact = 1.0 / (ScalarType)(size_t)nb_groups;

cpp/memilio/ad/ad.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,26 @@
2727
#include <cmath>
2828
#include <limits>
2929

30+
namespace ad
31+
{
32+
namespace internal
33+
{
34+
35+
/**
36+
* @brief Format AD types (like ad::gt1s<double>::type) using their value for logging with spdlog.
37+
*
38+
* If derivative information is needed as well, use `ad::derivative(...)` or define a `fmt::formatter<...>`.
39+
*/
40+
template <class FP, class DataHandler>
41+
const FP& format_as(const active_type<FP, DataHandler>& ad_type)
42+
{
43+
// Note: the format_as function needs to be in the same namespace as the value it takes
44+
return value(ad_type);
45+
}
46+
47+
} // namespace internal
48+
} // namespace ad
49+
3050
// Allow std::numeric_limits to work with AD types.
3151
template <class FP, class DataHandler>
3252
struct std::numeric_limits<ad::internal::active_type<FP, DataHandler>> : public numeric_limits<FP> {

cpp/memilio/compartments/compartmental_model.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,9 @@
2020
#ifndef MIO_COMPARTMENTS_COMPARTMENTAL_MODEL_H
2121
#define MIO_COMPARTMENTS_COMPARTMENTAL_MODEL_H
2222

23-
#include "memilio/config.h"
24-
#include "memilio/math/eigen.h"
23+
#include "memilio/config.h" // IWYU pragma: keep
24+
#include "memilio/math/eigen.h" // IWYU pragma: keep
25+
2526
#include <concepts>
2627

2728
namespace mio

cpp/memilio/compartments/feedback_simulation.h

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,11 @@
2020
#ifndef MIO_COMPARTMENTS_FEEDBACK_SIMULATION_H
2121
#define MIO_COMPARTMENTS_FEEDBACK_SIMULATION_H
2222

23-
#include <cassert>
24-
#include "memilio/compartments/simulation.h"
25-
#include "memilio/utils/time_series.h"
26-
#include "memilio/utils/parameter_set.h"
2723
#include "memilio/epidemiology/age_group.h"
28-
#include "memilio/utils/uncertain_value.h"
2924
#include "memilio/epidemiology/damping_sampling.h"
25+
#include "memilio/utils/parameter_set.h"
26+
#include "memilio/utils/time_series.h"
27+
#include "memilio/utils/uncertain_value.h"
3028

3129
namespace mio
3230
{

cpp/memilio/compartments/flow_model.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
#include "memilio/compartments/compartmental_model.h"
2525
#include "memilio/utils/index_range.h"
2626
#include "memilio/utils/flow.h"
27-
#include "memilio/utils/type_list.h"
27+
#include "memilio/utils/type_list.h" // IWYU pragma: keep for easier flow definitions
2828

2929
namespace mio
3030
{
@@ -37,8 +37,8 @@ namespace details
3737
// First a list of tuples is generated for each Tag in Tags, where the tuple is either of type tuple<Tag>, or if
3838
// Tag == OmittedTag, of type tuple<>. This list is then concatenated, effectively removing OmittedTag.
3939
template <class OmittedTag, class... Tags>
40-
decltype(std::tuple_cat(std::declval<typename std::conditional<std::is_same<OmittedTag, Tags>::value, std::tuple<>,
41-
std::tuple<Tags>>::type>()...))
40+
decltype(std::tuple_cat(
41+
std::declval<std::conditional_t<std::is_same_v<OmittedTag, Tags>, std::tuple<>, std::tuple<Tags>>>()...))
4242
filter_tuple(std::tuple<Tags...>);
4343

4444
// Function declaration used to replace type T by std::tuple.
@@ -118,7 +118,7 @@ class FlowModel : public CompartmentalModel<FP, Comp, Pop, Params>
118118
get_rhs_impl(flows, dydt, Index<>{});
119119
}
120120
else {
121-
for (FlowIndex I : make_index_range(reduce_index<FlowIndex>(this->populations.size()))) {
121+
for (FlowIndex I : reduce_index<FlowIndex>(this->populations.size())) {
122122
get_rhs_impl(flows, dydt, I);
123123
}
124124
}
@@ -136,7 +136,7 @@ class FlowModel : public CompartmentalModel<FP, Comp, Pop, Params>
136136
* @param[out] dydt A reference to the calculated output.
137137
*/
138138
void get_derivatives(Eigen::Ref<const Eigen::VectorX<FP>> pop, Eigen::Ref<const Eigen::VectorX<FP>> y, FP t,
139-
Eigen::Ref<Eigen::VectorX<FP>> dydt) const override final
139+
Eigen::Ref<Eigen::VectorX<FP>> dydt) const final
140140
{
141141
m_flow_values.setZero();
142142
get_flows(pop, y, t, m_flow_values);
@@ -200,7 +200,7 @@ class FlowModel : public CompartmentalModel<FP, Comp, Pop, Params>
200200
template <Comp Source, Comp Target>
201201
constexpr size_t get_flat_flow_index() const
202202
{
203-
static_assert(std::is_same<FlowIndex, Index<>>::value, "Other indices must be specified");
203+
static_assert(std::is_same_v<FlowIndex, Index<>>, "Other indices must be specified");
204204
return index_of_type_v<Flow<Source, Target>, Flows>;
205205
}
206206

cpp/memilio/compartments/parameter_studies.h

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,7 @@
2222

2323
#include "memilio/io/binary_serializer.h"
2424
#include "memilio/io/io.h"
25-
#include "memilio/mobility/graph_simulation.h"
2625
#include "memilio/utils/logging.h"
27-
#include "memilio/utils/metaprogramming.h"
2826
#include "memilio/utils/miompi.h"
2927
#include "memilio/utils/random_number_generator.h"
3028
#include "memilio/mobility/metapopulation_mobility_instant.h"
@@ -160,14 +158,11 @@ class ParameterStudy
160158
run(CreateSimulationFunction&& create_simulation, ProcessSimulationResultFunction&& process_simulation_result)
161159
{
162160
using ResultT = EnsembleResultT<CreateSimulationFunction, ProcessSimulationResultFunction>;
163-
int num_procs, rank;
161+
int num_procs = 1, rank = 0;
164162

165163
#ifdef MEMILIO_ENABLE_MPI
166164
MPI_Comm_size(mpi::get_world(), &num_procs);
167165
MPI_Comm_rank(mpi::get_world(), &rank);
168-
#else
169-
num_procs = 1;
170-
rank = 0;
171166
#endif
172167

173168
//The ParameterDistributions used for sampling parameters use thread_local_rng()
@@ -176,9 +171,8 @@ class ParameterStudy
176171
m_rng.synchronize();
177172

178173
std::vector<size_t> run_distribution = distribute_runs(m_num_runs, num_procs);
179-
size_t start_run_idx =
180-
std::accumulate(run_distribution.begin(), run_distribution.begin() + size_t(rank), size_t(0));
181-
size_t end_run_idx = start_run_idx + run_distribution[size_t(rank)];
174+
size_t start_run_idx = std::accumulate(run_distribution.begin(), run_distribution.begin() + rank, size_t(0));
175+
size_t end_run_idx = start_run_idx + run_distribution[rank];
182176

183177
if constexpr (std::is_void_v<ResultT>) {
184178
// if the processor returns nothing, there is nothing to synchronize

cpp/memilio/compartments/simulation.h

Lines changed: 8 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
#include "memilio/compartments/compartmental_model.h"
2424
#include "memilio/compartments/simulation_base.h"
2525
#include "memilio/math/integrator.h"
26-
#include "memilio/utils/metaprogramming.h"
2726
#include "memilio/math/stepper_wrapper.h"
2827
#include "memilio/utils/time_series.h"
2928

@@ -75,27 +74,15 @@ class Simulation : public details::SimulationBase<FP, M, OdeIntegrator<FP>>
7574
};
7675

7776
/**
78-
* Defines the return type of the `advance` member function of a type.
79-
* Template is invalid if this member function does not exist.
80-
*
81-
* @tparam FP floating point type, e.g., double
82-
* @tparam Sim a compartment model simulation type.
83-
*/
84-
template <typename FP, class Sim>
85-
using advance_expr_t = decltype(std::declval<Sim>().advance(std::declval<FP>()));
86-
87-
/**
88-
* Template meta function to check if a type is a compartment model simulation.
89-
* Defines a static constant of name `value`.
90-
* The constant `value` will be equal to true if Sim is a valid compartment simulation type.
91-
* Otherwise, `value` will be equal to false.
92-
* @tparam FP floating point type, e.g., double
93-
* @tparam Sim a type that may or may not be a compartment model simulation.
77+
* Concept to check if a type is a simulation for a compartmental model.
78+
* @tparam Simulation A type that may or may not be a compartmental model simulation.
79+
* @tparam FP A floating point type, e.g., double.
9480
*/
95-
template <typename FP, class Sim>
96-
using is_compartment_model_simulation =
97-
std::integral_constant<bool, (is_expression_valid<advance_expr_t, FP, Sim>::value &&
98-
IsCompartmentalModel<typename Sim::Model, FP>)>;
81+
template <class Simulation, typename FP>
82+
concept IsCompartmentalModelSimulation = requires(Simulation simulation, FP t) {
83+
requires IsCompartmentalModel<typename Simulation::Model, FP>;
84+
simulation.advance(t);
85+
};
9986

10087
/**
10188
* @brief Run a Simulation of a CompartmentalModel.

cpp/memilio/compartments/stochastic_model.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222

2323
#include "memilio/compartments/compartmental_model.h"
2424
#include "memilio/compartments/flow_model.h"
25-
#include "memilio/utils/metaprogramming.h"
2625
#include "memilio/utils/random_number_generator.h"
2726

2827
namespace mio

cpp/memilio/data/analyze_result.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,6 @@
1818
* limitations under the License.
1919
*/
2020
#include "memilio/data/analyze_result.h"
21-
#include "memilio/math/interpolation.h"
22-
23-
#include <algorithm>
24-
#include <cassert>
2521

2622
namespace mio
2723
{

cpp/memilio/data/analyze_result.h

Lines changed: 38 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#define MEMILIO_DATA_ANALYZE_RESULT_H
2222

2323
#include "memilio/config.h"
24+
#include "memilio/utils/logging.h"
2425
#include "memilio/utils/time_series.h"
2526
#include "memilio/mobility/metapopulation_mobility_instant.h"
2627
#include "memilio/math/interpolation.h"
@@ -376,81 +377,51 @@ template <class FP>
376377
IOResult<TimeSeries<FP>> merge_time_series(const TimeSeries<FP>& ts1, const TimeSeries<FP>& ts2,
377378
bool add_values = false)
378379
{
379-
TimeSeries<FP> merged_ts(ts1.get_num_elements());
380380
if (ts1.get_num_elements() != ts2.get_num_elements()) {
381381
log_error("TimeSeries have a different number of elements.");
382382
return failure(mio::StatusCode::InvalidValue);
383383
}
384-
else {
385-
Eigen::Index t1_iterator = 0;
386-
Eigen::Index t2_iterator = 0;
387-
bool t1_finished = false;
388-
bool t2_finished = false;
389-
while (!t1_finished || !t2_finished) {
390-
if (!t1_finished) {
391-
if (ts1.get_time(t1_iterator) < ts2.get_time(t2_iterator) ||
392-
t2_finished) { // Current time point of first TimeSeries is smaller than current time point of second TimeSeries or second TimeSeries has already been copied entirely
393-
merged_ts.add_time_point(ts1.get_time(t1_iterator), ts1.get_value(t1_iterator));
394-
t1_iterator += 1;
395-
}
396-
else if (!t2_finished && ts1.get_time(t1_iterator) ==
397-
ts2.get_time(t2_iterator)) { // Both TimeSeries have the current time point
398-
if (add_values) {
399-
merged_ts.add_time_point(ts1.get_time(t1_iterator),
400-
ts1.get_value(t1_iterator) + ts2.get_value(t2_iterator));
401-
}
402-
else {
403-
merged_ts.add_time_point(ts1.get_time(t1_iterator), ts1.get_value(t1_iterator));
404-
log_warning("Both TimeSeries have values for t={}. The value of the first TimeSeries is used",
405-
ts1.get_time(t1_iterator));
406-
}
407-
t1_iterator += 1;
408-
t2_iterator += 1;
409-
if (t2_iterator >=
410-
ts2.get_num_time_points()) { // Check if all values of second TimeSeries have been copied
411-
t2_finished = true;
412-
t2_iterator = ts2.get_num_time_points() - 1;
413-
}
414-
}
415-
if (t1_iterator >=
416-
ts1.get_num_time_points()) { // Check if all values of first TimeSeries have been copied
417-
t1_finished = true;
418-
t1_iterator = ts1.get_num_time_points() - 1;
419-
}
384+
if (!ts1.is_strictly_monotonic() || !ts2.is_strictly_monotonic()) {
385+
log_error("TimeSeries need to have strictly monotonic time points to be merged.");
386+
return failure(mio::StatusCode::InvalidValue);
387+
}
388+
Eigen::Index t1_iterator = 0;
389+
Eigen::Index t2_iterator = 0;
390+
const Eigen::Index t1_size = ts1.get_num_time_points();
391+
const Eigen::Index t2_size = ts2.get_num_time_points();
392+
TimeSeries<FP> merged_ts(ts1.get_num_elements());
393+
merged_ts.reserve(t1_size + t2_size);
394+
// merge entries of both time series until one finishes
395+
while (t1_iterator < t1_size && t2_iterator < t2_size) {
396+
// check which ts has the smaller time at the current iterator, and merge it
397+
if (ts1.get_time(t1_iterator) < ts2.get_time(t2_iterator)) {
398+
merged_ts.add_time_point(ts1.get_time(t1_iterator), ts1.get_value(t1_iterator));
399+
++t1_iterator;
400+
}
401+
else if (ts1.get_time(t1_iterator) == ts2.get_time(t2_iterator)) {
402+
merged_ts.add_time_point(ts1.get_time(t1_iterator), ts1.get_value(t1_iterator));
403+
if (add_values) {
404+
merged_ts.get_last_value() += ts2.get_value(t2_iterator);
420405
}
421-
if (!t2_finished) {
422-
if (ts2.get_time(t2_iterator) < ts1.get_time(t1_iterator) ||
423-
t1_finished) { // Current time point of second TimeSeries is smaller than current time point of first TimeSeries or first TimeSeries has already been copied entirely
424-
merged_ts.add_time_point(ts2.get_time(t2_iterator), ts2.get_value(t2_iterator));
425-
t2_iterator += 1;
426-
}
427-
else if (!t1_finished && ts2.get_time(t2_iterator) ==
428-
ts1.get_time(t1_iterator)) { // Both TimeSeries have the current time point
429-
if (add_values) {
430-
merged_ts.add_time_point(ts1.get_time(t1_iterator),
431-
ts1.get_value(t1_iterator) + ts2.get_value(t2_iterator));
432-
}
433-
else {
434-
merged_ts.add_time_point(ts1.get_time(t1_iterator), ts1.get_value(t1_iterator));
435-
log_warning("Both TimeSeries have values for t={}. The value of the first TimeSeries is used",
436-
ts1.get_time(t1_iterator));
437-
}
438-
t1_iterator += 1;
439-
t2_iterator += 1;
440-
if (t1_iterator >=
441-
ts1.get_num_time_points()) { // Check if all values of first TimeSeries have been copied
442-
t1_finished = true;
443-
t1_iterator = ts1.get_num_time_points() - 1;
444-
}
445-
}
446-
if (t2_iterator >=
447-
ts2.get_num_time_points()) { // Check if all values of second TimeSeries have been copied
448-
t2_finished = true;
449-
t2_iterator = ts2.get_num_time_points() - 1;
450-
}
406+
else {
407+
log_warning("Both TimeSeries have values for t={}. The value of the first TimeSeries is used",
408+
ts1.get_time(t1_iterator));
451409
}
410+
++t1_iterator;
411+
++t2_iterator;
412+
}
413+
else { // " > "
414+
merged_ts.add_time_point(ts2.get_time(t2_iterator), ts2.get_value(t2_iterator));
415+
++t2_iterator;
452416
}
453417
}
418+
// append remaining entries. at most one of the following for loops will be executed
419+
for (; t1_iterator < t1_size; ++t1_iterator) {
420+
merged_ts.add_time_point(ts1.get_time(t1_iterator), ts1.get_value(t1_iterator));
421+
}
422+
for (; t2_iterator < t2_size; ++t2_iterator) {
423+
merged_ts.add_time_point(ts2.get_time(t2_iterator), ts2.get_value(t2_iterator));
424+
}
454425
return success(merged_ts);
455426
}
456427

0 commit comments

Comments
 (0)