Skip to content

Commit c796281

Browse files
reneSchm and jubicker authored
1467 Make timers MPI compatible (#1468)
- Add TimerRegistrar::print_gathered_timers and supporting methods. - Update example and docs Co-authored-by: jubicker <[email protected]>
1 parent a4a7b77 commit c796281

12 files changed

Lines changed: 251 additions & 39 deletions

File tree

cpp/examples/performance_timers.cpp

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919
*/
2020
#include "memilio/timer/auto_timer.h"
2121
#include "memilio/timer/table_printer.h"
22+
#include "memilio/timer/timer_registrar.h"
23+
#include "memilio/utils/logging.h"
24+
#include "memilio/utils/miompi.h"
2225

2326
#include <thread> // This is only used for the example load function.
2427

@@ -40,8 +43,15 @@ void load()
4043

4144
int main()
4245
{
43-
// Specify the namespace of AutoTimer, so we don't have to repeat it. Do not do this with entire namespaces.
46+
mio::set_log_level(mio::LogLevel::info);
47+
mio::mpi::init();
48+
49+
const int comm_rank = mio::mpi::rank(mio::mpi::get_world());
50+
const int comm_size = mio::mpi::size(mio::mpi::get_world());
51+
// Specify the namespace of AutoTimer and TimerRegistrar, so we don't have to repeat it.
52+
// Avoid "using" statements with entire namespaces.
4453
using mio::timing::AutoTimer;
54+
using mio::timing::TimerRegistrar;
4555

4656
// Start immediately timing the main function. An AutoTimer starts timing upon its creation, and ends timing when
4757
// it is destroyed. This usually happens when a function returns, or a scope indicated by {curly braces} ends.
@@ -54,15 +64,21 @@ int main()
5464
// the end of the program. This can be disabled by calling `TimerRegistrar::disable_final_timer_summary()`.
5565
auto printer = std::make_unique<mio::timing::TablePrinter>();
5666
printer->set_time_format("{:e}");
57-
mio::timing::TimerRegistrar::get_instance().set_printer(std::move(printer));
67+
TimerRegistrar::get_instance().set_printer(std::move(printer));
5868

59-
// To manually print all timers, use `TimerRegistrar::print_timers()`, but make sure that no timers are running.
69+
// To manually print all timers, use `TimerRegistrar::print_timers()` (or `TimerRegistrar::print_gathered_timers()`
70+
// when using MPI), but make sure that no timers are running.
6071
// The "main" timer in this example would make that difficult, but you can simply add another scope around it,
6172
// similar to the "compute loops" timer below.
6273

63-
const int N = 1000; // Number of iterations. Be wary of increasing this number when using the sleep_for load!
74+
// Number of iterations. Be wary of increasing this number when using the sleep_for load!
75+
const int total_iterations = 1000;
76+
const int N = total_iterations / comm_size + (comm_rank < (total_iterations % comm_size));
6477

65-
std::cout << "Num threads: " << mio::omp::get_max_threads() << "\n";
78+
if (mio::mpi::is_root()) {
79+
mio::log_info("Num ranks: {} - Num threads: {} - Work size per rank: {}", comm_size,
80+
mio::omp::get_max_threads(), N);
81+
}
6682

6783
// Open a new scope to time computations.
6884
{
@@ -100,5 +116,15 @@ int main()
100116
}
101117
}
102118

119+
// For MPI parallel execution, we replace the final timer summary with a manual print. That way, we can gather
120+
// timers from all ranks into a single output, instead of each rank printing a separate output. You can also use a
121+
// custom printer for computing additional statistics.
122+
if (comm_size > 1) {
123+
TimerRegistrar::get_instance().disable_final_timer_summary();
124+
TimerRegistrar::get_instance().print_gathered_timers(mio::mpi::get_world());
125+
}
126+
127+
mio::mpi::finalize();
128+
103129
return 0;
104130
}

cpp/memilio/timer/basic_timer.h

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,11 @@
2020
#ifndef MIO_TIMER_BASIC_TIMER_H
2121
#define MIO_TIMER_BASIC_TIMER_H
2222

23+
#include "memilio/io/io.h"
2324
#include "memilio/timer/definitions.h"
2425

2526
#include <string_view>
27+
#include <utility>
2628

2729
namespace mio
2830
{
@@ -69,6 +71,37 @@ class BasicTimer
6971
should_be_running(false, "~BasicTimer");
7072
}
7173

74+
/**
75+
* serialize this.
76+
* @see mio::serialize
77+
*/
78+
template <class IOContext>
79+
void serialize(IOContext& io) const
80+
{
81+
auto obj = io.create_object("BasicTimer");
82+
obj.add_element("elapsed_time", details::convert_to_ticks(m_elapsed_time));
83+
}
84+
85+
/**
86+
* deserialize an object of this class.
87+
* @see mio::deserialize
88+
*/
89+
template <class IOContext>
90+
static IOResult<BasicTimer> deserialize(IOContext& io)
91+
{
92+
using Tick = decltype(details::convert_to_ticks(std::declval<DurationType>()));
93+
auto obj = io.expect_object("BasicTimer");
94+
auto et = obj.expect_element("elapsed_time", Tag<Tick>{});
95+
return apply(
96+
io,
97+
[](auto&& et_) {
98+
BasicTimer b;
99+
b.m_elapsed_time = DurationType{et_};
100+
return b;
101+
},
102+
et);
103+
}
104+
72105
private:
73106
TimeType m_start_time; ///< The last time point at which start() was called
74107
DurationType m_elapsed_time{0}; ///< The total time spent between starts and stops.

cpp/memilio/timer/definitions.h

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
#include "memilio/utils/mioomp.h"
2424

2525
#include <chrono>
26-
#include <cstddef>
26+
#include <ctime>
2727

2828
namespace mio
2929
{
@@ -42,6 +42,21 @@ using DurationType = std::chrono::steady_clock::duration;
4242

4343
#endif
4444

45+
namespace details
46+
{
47+
48+
/// @brief Convert a duration to integer ticks. Useful for serialization.
49+
inline decltype(auto) convert_to_ticks(DurationType duration)
50+
{
51+
#ifdef MEMILIO_ENABLE_OPENMP
52+
return duration;
53+
#else
54+
return duration.count();
55+
#endif
56+
}
57+
58+
} // namespace details
59+
4560
/**
4661
* @brief Convert a duration to a (floating point) number of seconds.
4762
* @param[in] duration Any DurationType value, mainly `BasicTimer::get_elapsed_time()`.

cpp/memilio/timer/list_printer.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#define MIO_TIMER_LIST_PRINTER_H
2222

2323
#include "memilio/timer/registration.h"
24+
#include "memilio/utils/compiler_diagnostics.h"
2425

2526
#include <ostream>
2627
#include <list>
@@ -45,16 +46,16 @@ struct ListPrinter : public Printer {
4546
const auto indent = " ";
4647
// Write out all timers
4748
out << "All Timers: " << timer_register.size() << "\n";
48-
for (const auto& [name, scope, timer, thread] : timer_register) {
49+
for (const auto& [name, scope, timer, thread, rank] : timer_register) {
4950
out << indent << qualified_name(name, scope) << ": " << std::scientific
50-
<< time_in_seconds(timer.get_elapsed_time()) << " (" << thread << ")\n";
51-
is_multithreaded |= thread > 0;
51+
<< time_in_seconds(timer.get_elapsed_time()) << " (" << rank << ", " << thread << ")\n";
52+
is_multithreaded |= thread > 0 || rank > 0;
5253
}
5354
// Write out timers accumulated over threads by name
5455
if (is_multithreaded) {
5556
// dedupe list entries from parallel execution
5657
std::map<std::string, DurationType> deduper;
57-
for (const auto& [name, scope, timer, _] : timer_register) {
58+
for (const auto& [name, scope, timer, thread, rank] : timer_register) {
5859
deduper[qualified_name(name, scope)] += timer.get_elapsed_time();
5960
}
6061
out << "Unique Timers (accumulated): " << deduper.size() << "\n";

cpp/memilio/timer/named_timer.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ class NamedTimer : public BasicTimer
7777
*/
7878
NamedTimer()
7979
{
80-
TimerRegistrar::get_instance().add_timer({name(), scope(), *this, mio::omp::get_thread_id()});
80+
TimerRegistrar::get_instance().add_timer({name(), scope(), *this, mio::omp::get_thread_id(), 0});
8181
}
8282
};
8383

cpp/memilio/timer/registration.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#ifndef MIO_TIMER_REGISTRATION_H
2121
#define MIO_TIMER_REGISTRATION_H
2222

23+
#include "memilio/io/default_serialize.h"
2324
#include "memilio/timer/basic_timer.h"
2425

2526
#include <list>
@@ -30,11 +31,13 @@ namespace mio
3031
namespace timing
3132
{
3233

34+
/// @brief Struct used to register ( @see TimerRegistrar ) and print timers ( @see Printer ).
3335
struct TimerRegistration {
3436
std::string name;
3537
std::string scope;
3638
BasicTimer& timer;
3739
int thread_id;
40+
int rank;
3841
};
3942

4043
/**

cpp/memilio/timer/table_printer.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -206,8 +206,8 @@ class TablePrinter : public Printer
206206
std::map<std::string, size_t> row_to_index; // map from name to index, used to fill table
207207
bool is_multithreaded = false; // keep track of whether a thread id > 0 exists
208208
// map rows from thread 0 first, so the order of timers (mostly) corresponds to their call order
209-
for (const auto& [name, scope, _, thread] : timer_register) {
210-
if (thread == 0) {
209+
for (const auto& [name, scope, timer, thread, rank] : timer_register) {
210+
if (thread == 0 && rank == 0) {
211211
const std::string qn = qualified_name(name, scope);
212212
if (row_to_index.emplace(qn, rows.size()).second) {
213213
rows.push_back(qn);
@@ -220,8 +220,8 @@ class TablePrinter : public Printer
220220
// make a second pass to add timers from other threads
221221
// this does nothing, if all timers are used on thread 0 at least once
222222
if (is_multithreaded) {
223-
for (auto& [name, scope, _, thread] : timer_register) {
224-
if (thread != 0) {
223+
for (auto& [name, scope, timer, thread, rank] : timer_register) {
224+
if (thread != 0 || rank != 0) {
225225
const std::string qn = qualified_name(name, scope);
226226
if (row_to_index.emplace(qn, rows.size()).second) {
227227
rows.push_back(qn);
@@ -245,7 +245,7 @@ class TablePrinter : public Printer
245245
}
246246
// accumulate elapsed time and gather statistics in the table
247247
// averages are calculated later, using finished values from elapsed and num
248-
for (auto& [name, scope, timer, thread] : timer_register) {
248+
for (auto& [name, scope, timer, thread, rank] : timer_register) {
249249
const auto row = row_to_index[qualified_name(name, scope)];
250250
const auto time = time_in_seconds(timer.get_elapsed_time());
251251
table(row, elapsed) += time;

cpp/memilio/timer/timer_registrar.h

Lines changed: 86 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,19 @@
2020
#ifndef MIO_TIMER_TIMER_REGISTRAR_H
2121
#define MIO_TIMER_TIMER_REGISTRAR_H
2222

23+
#include "memilio/io/binary_serializer.h"
24+
#include "memilio/io/default_serialize.h"
25+
#include "memilio/timer/basic_timer.h"
2326
#include "memilio/timer/registration.h"
2427
#include "memilio/timer/table_printer.h"
28+
#include "memilio/utils/compiler_diagnostics.h"
29+
#include "memilio/utils/miompi.h"
2530

31+
#include <cassert>
2632
#include <iostream>
2733
#include <list>
2834
#include <memory>
2935
#include <mutex>
30-
#include <ostream>
3136

3237
namespace mio
3338
{
@@ -88,14 +93,35 @@ class TimerRegistrar
8893
{
8994
PRAGMA_OMP(single nowait)
9095
{
91-
get_instance().m_printer->print(m_register, out);
96+
m_printer->print(m_register, out);
97+
}
98+
}
99+
100+
/**
101+
* @brief Print all timers gathered on the given MPI Comm using a Printer.
102+
*
103+
* By default, uses TablePrinter to write to stdout. The printer can be changed using the set_printer member.
104+
*
105+
* @param comm An MPI communicator. Make sure this method is called by all ranks on comm!
106+
* @param out Any output stream, defaults to std::cout.
107+
*/
108+
void print_gathered_timers(mpi::Comm comm, std::ostream& out = std::cout)
109+
{
110+
PRAGMA_OMP(single nowait)
111+
{
112+
std::list<BasicTimer> external_timers; // temporary storage for gathered registrations
113+
std::list<TimerRegistration> gathered_register; // combined registrations of all ranks
114+
gather_timers(comm, external_timers, gathered_register);
115+
if (mpi::is_root()) {
116+
m_printer->print(gathered_register, out);
117+
}
92118
}
93119
}
94120

95121
/// @brief Prevent the TimerRegistrar from calling print_timers on exit from main.
96-
void disable_final_timer_summary() const
122+
void disable_final_timer_summary()
97123
{
98-
get_instance().m_print_on_death = false;
124+
m_print_on_death = false;
99125
}
100126

101127
/**
@@ -143,6 +169,62 @@ class TimerRegistrar
143169
}
144170
}
145171

172+
/// @brief Gather timers from all ranks, using external_timers as temporary timer storage for gathered_register.
173+
void gather_timers(mpi::Comm comm, std::list<BasicTimer>& external_timers,
174+
std::list<TimerRegistration>& gathered_register) const
175+
{
176+
#ifndef MEMILIO_ENABLE_MPI
177+
mio::unused(comm, external_timers);
178+
#endif
179+
if (mpi::is_root()) {
180+
std::ranges::transform(m_register, std::back_inserter(gathered_register), [](const TimerRegistration& r) {
181+
return TimerRegistration{r.name, r.scope, r.timer, r.thread_id, 0};
182+
});
183+
}
184+
if (comm == nullptr) {
185+
if (mpi::is_root()) {
186+
log_error("Got nullptr as MPI Comm. Only timers on root rank are gathered.");
187+
}
188+
return;
189+
}
190+
#ifdef MEMILIO_ENABLE_MPI
191+
// name, scope, timer, thread_id
192+
using GatherableRegistration = std::tuple<std::string, std::string, BasicTimer, int>;
193+
int comm_size;
194+
MPI_Comm_size(comm, &comm_size);
195+
196+
if (mpi::is_root()) {
197+
for (int snd_rank = 1; snd_rank < comm_size; snd_rank++) { // skip root rank!
198+
int bytes_size;
199+
MPI_Recv(&bytes_size, 1, MPI_INT, snd_rank, 0, comm, MPI_STATUS_IGNORE);
200+
ByteStream bytes(bytes_size);
201+
MPI_Recv(bytes.data(), bytes.data_size(), MPI_BYTE, snd_rank, 0, mpi::get_world(), MPI_STATUS_IGNORE);
202+
203+
auto rec_register = deserialize_binary(bytes, Tag<std::vector<GatherableRegistration>>{});
204+
if (!rec_register) {
205+
log_error("Error receiving ensemble results from rank {}.", snd_rank);
206+
}
207+
std::ranges::transform(
208+
rec_register.value(), std::back_inserter(gathered_register), [&](const GatherableRegistration& r) {
209+
const auto& [name, scope, timer, thread_id] = r;
210+
external_timers.push_back(timer);
211+
return TimerRegistration{name, scope, external_timers.back(), thread_id, snd_rank};
212+
});
213+
}
214+
}
215+
else {
216+
std::vector<GatherableRegistration> snd_register;
217+
std::ranges::transform(m_register, std::back_inserter(snd_register), [](const TimerRegistration& r) {
218+
return GatherableRegistration{r.name, r.scope, r.timer, r.thread_id};
219+
});
220+
ByteStream bytes = serialize_binary(snd_register);
221+
int bytes_size = int(bytes.data_size());
222+
MPI_Send(&bytes_size, 1, MPI_INT, 0, 0, comm);
223+
MPI_Send(bytes.data(), bytes.data_size(), MPI_BYTE, 0, 0, comm);
224+
}
225+
#endif
226+
}
227+
146228
std::unique_ptr<Printer> m_printer = std::make_unique<TablePrinter>(); ///< A printer to visualize all timers.
147229
bool m_print_on_death = true; ///< Whether to call m_printer during the destructor.
148230
std::list<TimerRegistration> m_register; ///< List that allows access to timers without having their name.

0 commit comments

Comments
 (0)