|
20 | 20 | #ifndef MIO_TIMER_TIMER_REGISTRAR_H |
21 | 21 | #define MIO_TIMER_TIMER_REGISTRAR_H |
22 | 22 |
|
| 23 | +#include "memilio/io/binary_serializer.h" |
| 24 | +#include "memilio/io/default_serialize.h" |
| 25 | +#include "memilio/timer/basic_timer.h" |
23 | 26 | #include "memilio/timer/registration.h" |
24 | 27 | #include "memilio/timer/table_printer.h" |
| 28 | +#include "memilio/utils/compiler_diagnostics.h" |
| 29 | +#include "memilio/utils/miompi.h" |
25 | 30 |
|
| 31 | +#include <cassert> |
26 | 32 | #include <iostream> |
27 | 33 | #include <list> |
28 | 34 | #include <memory> |
29 | 35 | #include <mutex> |
30 | | -#include <ostream> |
31 | 36 |
|
32 | 37 | namespace mio |
33 | 38 | { |
@@ -88,14 +93,35 @@ class TimerRegistrar |
88 | 93 | { |
89 | 94 | PRAGMA_OMP(single nowait) |
90 | 95 | { |
91 | | - get_instance().m_printer->print(m_register, out); |
| 96 | + m_printer->print(m_register, out); |
| 97 | + } |
| 98 | + } |
| 99 | + |
| 100 | + /** |
| 101 | + * @brief Print all timers gathered on the given MPI Comm using a Printer. |
| 102 | + * |
| 103 | + * By default, uses TablePrinter to write to stdout. The printer can be changed using the set_printer member. |
| 104 | + * |
| 105 | + * @param comm An MPI communicator. Make sure this method is called by all ranks on comm! |
| 106 | + * @param out Any output stream, defaults to std::cout. |
| 107 | + */ |
| 108 | + void print_gathered_timers(mpi::Comm comm, std::ostream& out = std::cout) |
| 109 | + { |
| 110 | + PRAGMA_OMP(single nowait) |
| 111 | + { |
| 112 | + std::list<BasicTimer> external_timers; // temporary storage for gathered registrations |
| 113 | + std::list<TimerRegistration> gathered_register; // combined registrations of all ranks |
| 114 | + gather_timers(comm, external_timers, gathered_register); |
| 115 | + if (mpi::is_root()) { |
| 116 | + m_printer->print(gathered_register, out); |
| 117 | + } |
92 | 118 | } |
93 | 119 | } |
94 | 120 |
|
95 | 121 | /// @brief Prevent the TimerRegistrar from calling print_timers on exit from main. |
96 | | - void disable_final_timer_summary() const |
| 122 | + void disable_final_timer_summary() |
97 | 123 | { |
98 | | - get_instance().m_print_on_death = false; |
| 124 | + m_print_on_death = false; |
99 | 125 | } |
100 | 126 |
|
101 | 127 | /** |
@@ -143,6 +169,62 @@ class TimerRegistrar |
143 | 169 | } |
144 | 170 | } |
145 | 171 |
|
| 172 | + /// @brief Gather timers from all ranks, using external_timers as temporary timer storage for gathered_register. |
| 173 | + void gather_timers(mpi::Comm comm, std::list<BasicTimer>& external_timers, |
| 174 | + std::list<TimerRegistration>& gathered_register) const |
| 175 | + { |
| 176 | +#ifndef MEMILIO_ENABLE_MPI |
| 177 | + mio::unused(comm, external_timers); |
| 178 | +#endif |
| 179 | + if (mpi::is_root()) { |
| 180 | + std::ranges::transform(m_register, std::back_inserter(gathered_register), [](const TimerRegistration& r) { |
| 181 | + return TimerRegistration{r.name, r.scope, r.timer, r.thread_id, 0}; |
| 182 | + }); |
| 183 | + } |
| 184 | + if (comm == nullptr) { |
| 185 | + if (mpi::is_root()) { |
| 186 | + log_error("Got nullptr as MPI Comm. Only timers on root rank are gathered."); |
| 187 | + } |
| 188 | + return; |
| 189 | + } |
| 190 | +#ifdef MEMILIO_ENABLE_MPI |
| 191 | + // name, scope, timer, thread_id |
| 192 | + using GatherableRegistration = std::tuple<std::string, std::string, BasicTimer, int>; |
| 193 | + int comm_size; |
| 194 | + MPI_Comm_size(comm, &comm_size); |
| 195 | + |
| 196 | + if (mpi::is_root()) { |
| 197 | + for (int snd_rank = 1; snd_rank < comm_size; snd_rank++) { // skip root rank! |
| 198 | + int bytes_size; |
| 199 | + MPI_Recv(&bytes_size, 1, MPI_INT, snd_rank, 0, comm, MPI_STATUS_IGNORE); |
| 200 | + ByteStream bytes(bytes_size); |
| 201 | + MPI_Recv(bytes.data(), bytes.data_size(), MPI_BYTE, snd_rank, 0, mpi::get_world(), MPI_STATUS_IGNORE); |
| 202 | + |
| 203 | + auto rec_register = deserialize_binary(bytes, Tag<std::vector<GatherableRegistration>>{}); |
| 204 | + if (!rec_register) { |
| 205 | + log_error("Error receiving ensemble results from rank {}.", snd_rank); |
| 206 | + } |
| 207 | + std::ranges::transform( |
| 208 | + rec_register.value(), std::back_inserter(gathered_register), [&](const GatherableRegistration& r) { |
| 209 | + const auto& [name, scope, timer, thread_id] = r; |
| 210 | + external_timers.push_back(timer); |
| 211 | + return TimerRegistration{name, scope, external_timers.back(), thread_id, snd_rank}; |
| 212 | + }); |
| 213 | + } |
| 214 | + } |
| 215 | + else { |
| 216 | + std::vector<GatherableRegistration> snd_register; |
| 217 | + std::ranges::transform(m_register, std::back_inserter(snd_register), [](const TimerRegistration& r) { |
| 218 | + return GatherableRegistration{r.name, r.scope, r.timer, r.thread_id}; |
| 219 | + }); |
| 220 | + ByteStream bytes = serialize_binary(snd_register); |
| 221 | + int bytes_size = int(bytes.data_size()); |
| 222 | + MPI_Send(&bytes_size, 1, MPI_INT, 0, 0, comm); |
| 223 | + MPI_Send(bytes.data(), bytes.data_size(), MPI_BYTE, 0, 0, comm); |
| 224 | + } |
| 225 | +#endif |
| 226 | + } |
| 227 | + |
146 | 228 | std::unique_ptr<Printer> m_printer = std::make_unique<TablePrinter>(); ///< A printer to visualize all timers. |
147 | 229 | bool m_print_on_death = true; ///< Whether to call m_printer during the destructor. |
148 | 230 | std::list<TimerRegistration> m_register; ///< List that allows access to timers without having their name. |
|
0 commit comments