forked from OpenNMT/CTranslate2
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmpi.cc
More file actions
30 lines (22 loc) · 824 Bytes
/
mpi.cc
File metadata and controls
30 lines (22 loc) · 824 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
#include "module.h"
#include <ctranslate2/devices.h>
#include "utils.h"
namespace ctranslate2 {
  namespace python {

    // Registers the Python class `MpiInfo` on module `m`.
    //
    // The class wraps the C++ type ScopedMPISetter, but only its three static
    // query functions are bound (getNRanks, getCurRank, getLocalRank). No
    // constructor is bound here, so from Python the class is presumably meant
    // to be used through these static methods alone — confirm with callers.
    void register_mpi(py::module& m) {
      py::class_<ScopedMPISetter>(
        m, "MpiInfo",
        // Raw-string contents are the Python docstring; their whitespace is
        // runtime text and is intentionally left unindented.
        R"pbdoc(
An object to manage the MPI communication between processes.
It provides information about MPI connection.
)pbdoc")

        // NOTE(review): this docstring equates ranks with GPUs, which assumes
        // one process per GPU — confirm against ScopedMPISetter::getNRanks.
        .def_static("getNRanks", &ScopedMPISetter::getNRanks,
                    "Get the number of gpus running for the current model.")

        .def_static("getCurRank", &ScopedMPISetter::getCurRank,
                    "Get the current rank of process.")

        .def_static("getLocalRank", &ScopedMPISetter::getLocalRank,
                    "Get the current GPU id used by process.")
        ;
    }

  }
}