forked from OpenNMT/CTranslate2
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathreplica_pool.h
More file actions
197 lines (158 loc) · 6.6 KB
/
replica_pool.h
File metadata and controls
197 lines (158 loc) · 6.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
#pragma once
#include <shared_mutex>
#include <ctranslate2/replica_pool.h>
#include "utils.h"
namespace ctranslate2 {
namespace python {
inline std::shared_ptr<models::ModelReader>
create_model_reader(const std::string& model, py::object files) {
if (files.is_none())
return std::make_shared<models::ModelFileReader>(model);
if (!py::isinstance<py::dict>(files))
throw pybind11::type_error("files argument must be a dictionary mapping file names "
"to the file contents");
auto reader = std::make_shared<models::ModelMemoryReader>(model);
for (const auto& pair : files.cast<py::dict>()) {
auto filename = pair.first;
auto content = pair.second;
auto read = py::getattr(content, "read", py::none());
if (!read.is_none())
content = read();
else if (!py::isinstance<py::bytes>(content))
throw pybind11::type_error("File content must be a file-like or bytes object");
reader->register_file(filename.cast<std::string>(), content.cast<std::string>());
}
return reader;
}
template <typename T>
class ReplicaPoolHelper {
public:
ReplicaPoolHelper(const std::string& model_path,
const std::string& device,
const std::variant<int, std::vector<int>>& device_index,
const StringOrMap& compute_type,
size_t inter_threads,
size_t intra_threads,
long max_queued_batches,
bool flash_attention,
bool tensor_parallel,
py::object files)
: _model_loader(create_model_reader(model_path, files))
, _device(str_to_device(device))
, _num_replicas_per_device(inter_threads)
{
pybind11::gil_scoped_release nogil;
_model_loader.device = str_to_device(device);
_model_loader.device_indices = std::visit(DeviceIndexResolver(), device_index);
_model_loader.compute_type = std::visit(ComputeTypeResolver(device), compute_type);
_model_loader.num_replicas_per_device = inter_threads;
_model_loader.use_flash_attention = flash_attention;
_model_loader.tensor_parallel = tensor_parallel;
_pool_config.num_threads_per_replica = intra_threads;
_pool_config.max_queued_batches = max_queued_batches;
_pool = std::make_unique<T>(_model_loader, _pool_config);
_device_index = _model_loader.device_indices;
_model_is_loaded = true;
}
~ReplicaPoolHelper() {
pybind11::gil_scoped_release nogil;
_pool.reset();
}
std::string device() const {
return device_to_str(_model_loader.device);
}
const std::vector<int>& device_index() const {
return _model_loader.device_indices;
}
std::string compute_type() const {
return compute_type_to_str(model()->effective_compute_type());
}
bool tensor_parallel() const {
return _model_loader.tensor_parallel;
}
size_t num_replicas() const {
return _pool->num_replicas();
}
size_t num_queued_batches() const {
return _pool->num_queued_batches();
}
size_t num_active_batches() const {
return _pool->num_active_batches();
}
bool model_is_loaded() {
std::shared_lock lock(_mutex);
return _model_is_loaded;
}
void unload_model(const bool to_cpu) {
if (to_cpu && _device == Device::CPU)
return;
// Do not unload the model if some batches are still being processed.
if (_pool->num_active_batches() > 0)
return;
// If the lock is not acquired immediately it means the model is being used
// in another thread and we can't unload it at this time.
std::unique_lock lock(_mutex, std::try_to_lock);
if (!lock)
return;
std::vector<std::shared_ptr<const models::Model>> loaded_models;
if (_model_is_loaded)
loaded_models = _pool->detach_models();
if (to_cpu && _cached_models.empty())
_cached_models = clone_models(Device::CPU, std::vector<int>(loaded_models.size(), 0), loaded_models);
else if (!to_cpu)
_cached_models.clear();
loaded_models.clear();
// We clear the CUDA allocator cache to further reduce the memory after unloading the model.
if (_device == Device::CUDA)
_pool->clear_cache();
_model_is_loaded = false;
}
void load_model(const bool keep_cache) {
std::unique_lock lock(_mutex);
if (_model_is_loaded)
return;
std::vector<std::shared_ptr<const models::Model>> loaded_models;
if (_cached_models.empty())
loaded_models = _model_loader.load();
else
loaded_models = clone_models(_device, _device_index, _cached_models, _num_replicas_per_device);
_pool->set_models(loaded_models);
if (!keep_cache)
_cached_models.clear();
_model_is_loaded = true;
}
protected:
std::unique_ptr<T> _pool;
models::ModelLoader _model_loader;
ReplicaPoolConfig _pool_config;
const Device _device;
const size_t _num_replicas_per_device;
std::vector<int> _device_index;
std::vector<std::shared_ptr<const models::Model>> _cached_models;
bool _model_is_loaded;
// Use a shared mutex to protect the model state (loaded/unloaded).
// Multiple threads can read the model at the same time, but a single thread can change
// the model state (e.g. load or unload the model).
std::shared_mutex _mutex;
std::vector<std::shared_ptr<const models::Model>> clone_models(Device device,
const std::vector<int>& device_index,
std::vector<std::shared_ptr<const models::Model>> cached_models,
size_t num_models_per_device = 1) {
std::vector<std::shared_ptr<const models::Model>> copied_models;
for (size_t i = 0; i < cached_models.size(); ++i) {
auto& model = const_cast<models::Model&>(*cached_models[i]);
auto copied_model = model.copy_to(device, device_index[i / num_models_per_device]);
copied_models.push_back(copied_model);
}
return copied_models;
}
const std::shared_ptr<const models::Model>& model() const {
return _pool->get_first_replica().model();
}
void assert_model_is_ready() const {
if (!_model_is_loaded)
throw std::runtime_error("The model for this translator was unloaded");
}
};
}
}