forked from OpenNMT/CTranslate2
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathwhisper.cc
More file actions
370 lines (318 loc) · 17.3 KB
/
whisper.cc
File metadata and controls
370 lines (318 loc) · 17.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
#include "module.h"
#include <ctranslate2/models/whisper.h>
#include "replica_pool.h"
namespace ctranslate2 {
namespace python {
// Python-facing wrapper around a pool of Whisper model replicas.
// Every public method forwards to the underlying ReplicaPool and, for the
// synchronous entry points, blocks on the returned futures.
class WhisperWrapper : public ReplicaPoolHelper<models::Whisper> {
public:
  using ReplicaPoolHelper::ReplicaPoolHelper;

  // Whether the loaded model supports more than one language.
  bool is_multilingual() const {
    return _pool->is_multilingual();
  }

  // Number of mel bands expected in the input features.
  size_t n_mels() const {
    return _pool->n_mels();
  }

  // Number of languages the model was trained to recognize.
  size_t num_languages() const {
    return _pool->num_languages();
  }

  // Runs the encoder on the features and waits for the encoder output.
  StorageView encode(const StorageView& features, const bool to_cpu) {
    auto pending = _pool->encode(features, to_cpu);
    return pending.get();
  }

  // Encodes the features and decodes from the given prompts.
  // Returns either resolved results or async handles, depending on
  // the asynchronous flag.
  std::variant<std::vector<models::WhisperGenerationResult>,
               std::vector<AsyncResult<models::WhisperGenerationResult>>>
  generate(const StorageView& features,
           std::variant<BatchTokens, BatchIds> prompts,
           bool asynchronous,
           size_t beam_size,
           float patience,
           size_t num_hypotheses,
           float length_penalty,
           float repetition_penalty,
           size_t no_repeat_ngram_size,
           size_t max_length,
           bool return_scores,
           bool return_logits_vocab,
           bool return_no_speech_prob,
           size_t max_initial_timestamp_index,
           bool suppress_blank,
           const std::optional<std::vector<int>>& suppress_tokens,
           size_t sampling_topk,
           float sampling_temperature) {
    // Marshal the Python keyword arguments into the model options struct.
    models::WhisperOptions opts;
    opts.num_hypotheses = num_hypotheses;
    opts.beam_size = beam_size;
    opts.patience = patience;
    opts.length_penalty = length_penalty;
    opts.repetition_penalty = repetition_penalty;
    opts.no_repeat_ngram_size = no_repeat_ngram_size;
    opts.max_length = max_length;
    opts.sampling_topk = sampling_topk;
    opts.sampling_temperature = sampling_temperature;
    opts.return_scores = return_scores;
    opts.return_logits_vocab = return_logits_vocab;
    opts.return_no_speech_prob = return_no_speech_prob;
    opts.max_initial_timestamp_index = max_initial_timestamp_index;
    opts.suppress_blank = suppress_blank;
    // Passing None from Python means "suppress nothing": clear any
    // default list; otherwise forward the explicit token IDs.
    if (suppress_tokens.has_value())
      opts.suppress_tokens = *suppress_tokens;
    else
      opts.suppress_tokens.clear();

    std::shared_lock lock(_mutex);
    assert_model_is_ready();

    // Dispatch on the prompt representation: string tokens or token IDs.
    std::vector<std::future<models::WhisperGenerationResult>> pending;
    if (std::holds_alternative<BatchTokens>(prompts))
      pending = _pool->generate(features, std::get<BatchTokens>(prompts), opts);
    else
      pending = _pool->generate(features, std::get<BatchIds>(prompts), opts);
    return maybe_wait_on_futures(std::move(pending), asynchronous);
  }

  // Returns, for each batch entry, the (language, probability) pairs.
  std::vector<std::vector<std::pair<std::string, float>>>
  detect_language(const StorageView& features) {
    std::shared_lock lock(_mutex);
    assert_model_is_ready();
    auto pending = _pool->detect_language(features);
    return wait_on_futures(std::move(pending));
  }

  // Aligns text tokens with the audio features.
  std::vector<models::WhisperAlignmentResult>
  align(const StorageView& features,
        Ids start_sequence,
        BatchIds text_tokens,
        const std::variant<size_t, std::vector<size_t>>& num_frames,
        size_t median_filter_width) {
    // A scalar frame count is broadcast over the whole batch; a list is
    // taken as-is (one entry per batch item).
    std::vector<size_t> frames_per_example;
    if (std::holds_alternative<std::vector<size_t>>(num_frames))
      frames_per_example = std::get<std::vector<size_t>>(num_frames);
    else
      frames_per_example.assign(text_tokens.size(), std::get<size_t>(num_frames));

    std::shared_lock lock(_mutex);
    assert_model_is_ready();
    auto pending = _pool->align(features,
                                std::move(start_sequence),
                                std::move(text_tokens),
                                std::move(frames_per_example),
                                median_filter_width);
    return wait_on_futures(std::move(pending));
  }
};
// Registers the Whisper bindings in the given Python module: the result
// classes (generation and alignment), their async wrappers, and the
// Whisper model class itself with all of its methods and properties.
void register_whisper(py::module& m) {
// Result of Whisper.generate(): token sequences plus optional scores,
// logits, and no-speech probability (each populated only when the
// corresponding return_* flag was enabled).
py::class_<models::WhisperGenerationResult>(m, "WhisperGenerationResult",
"A generation result from the Whisper model.")
.def_readonly("sequences", &models::WhisperGenerationResult::sequences,
"Generated sequences of tokens.")
.def_readonly("sequences_ids", &models::WhisperGenerationResult::sequences_ids,
"Generated sequences of token IDs.")
.def_readonly("scores", &models::WhisperGenerationResult::scores,
"Score of each sequence (empty if :obj:`return_scores` was disabled).")
.def_readonly("logits", &models::WhisperGenerationResult::logits,
"logits in each sequence (empty if :obj:`return_logits_vocab` was disabled).")
.def_readonly("no_speech_prob", &models::WhisperGenerationResult::no_speech_prob,
"Probability of the no speech token (0 if :obj:`return_no_speech_prob` was disabled).")
// NOTE(review): the repr omits the "logits" field even though it is
// exposed above — presumably deliberate since logits can be large;
// confirm before "fixing".
.def("__repr__", [](const models::WhisperGenerationResult& result) {
return "WhisperGenerationResult(sequences=" + std::string(py::repr(py::cast(result.sequences)))
+ ", sequences_ids=" + std::string(py::repr(py::cast(result.sequences_ids)))
+ ", scores=" + std::string(py::repr(py::cast(result.scores)))
+ ", no_speech_prob=" + std::string(py::repr(py::cast(result.no_speech_prob)))
+ ")";
})
;
// Async handle type returned by generate(..., asynchronous=True).
declare_async_wrapper<models::WhisperGenerationResult>(m, "WhisperGenerationResultAsync");
// Result of Whisper.align(): token/time alignments and per-token probabilities.
py::class_<models::WhisperAlignmentResult>(m, "WhisperAlignmentResult",
"An alignment result from the Whisper model.")
.def_readonly("alignments", &models::WhisperAlignmentResult::alignments,
"List of aligned text and time indices.")
.def_readonly("text_token_probs", &models::WhisperAlignmentResult::text_token_probs,
"Probabilities of text tokens.")
.def("__repr__", [](const models::WhisperAlignmentResult& result) {
return "WhisperAlignmentResult(alignments=" + std::string(py::repr(py::cast(result.alignments)))
+ ", text_token_probs=" + std::string(py::repr(py::cast(result.text_token_probs)))
+ ")";
})
;
// The model class. Methods that run the model release the GIL
// (py::gil_scoped_release) so other Python threads can proceed.
py::class_<WhisperWrapper>(
m, "Whisper",
R"pbdoc(
Implements the Whisper speech recognition model published by OpenAI.
See Also:
https://github.com/openai/whisper
)pbdoc")
.def_property_readonly("is_multilingual", &WhisperWrapper::is_multilingual,
"Returns ``True`` if this model is multilingual.")
.def_property_readonly("n_mels", &WhisperWrapper::n_mels,
"Returns dimension of mel input features.")
.def_property_readonly("num_languages", &WhisperWrapper::num_languages,
"Returns the number of languages supported.")
// Constructor: positional model_path and device, everything after is
// keyword-only (py::kw_only).
.def(py::init<const std::string&, const std::string&, const std::variant<int, std::vector<int>>&, const StringOrMap&, size_t, size_t, long, bool, bool, py::object>(),
py::arg("model_path"),
py::arg("device")="cpu",
py::kw_only(),
py::arg("device_index")=0,
py::arg("compute_type")="default",
py::arg("inter_threads")=1,
py::arg("intra_threads")=0,
py::arg("max_queued_batches")=0,
py::arg("flash_attention")=false,
py::arg("tensor_parallel")=false,
py::arg("files")=py::none(),
R"pbdoc(
Initializes a Whisper model from a converted model.
Arguments:
model_path: Path to the CTranslate2 model directory.
device: Device to use (possible values are: cpu, cuda, auto).
device_index: Device IDs where to place this model on.
compute_type: Model computation type or a dictionary mapping a device name
to the computation type (possible values are: default, auto, int8, int8_float32,
int8_float16, int8_bfloat16, int16, float16, bfloat16, float32).
inter_threads: Number of workers to allow executing multiple batches in parallel.
intra_threads: Number of OpenMP threads per worker (0 to use a default value).
max_queued_batches: Maximum numbers of batches in the worker queue (-1 for unlimited,
0 for an automatic value). When the queue is full, future requests will block
until a free slot is available.
flash_attention: run model with flash attention 2 for self-attention layer
tensor_parallel: run model with tensor parallel mode
files: Load model files from the memory. This argument is a dictionary mapping
file names to file contents as file-like or bytes objects. If this is set,
:obj:`model_path` acts as an identifier for this model.
)pbdoc")
// Read-only properties inherited from the replica-pool helper.
.def_property_readonly("device", &WhisperWrapper::device,
"Device this model is running on.")
.def_property_readonly("device_index", &WhisperWrapper::device_index,
"List of device IDs where this model is running on.")
.def_property_readonly("compute_type", &WhisperWrapper::compute_type,
"Computation type used by the model.")
.def_property_readonly("num_workers", &WhisperWrapper::num_replicas,
"Number of model workers backing this instance.")
.def_property_readonly("num_queued_batches", &WhisperWrapper::num_queued_batches,
"Number of batches waiting to be processed.")
.def_property_readonly("tensor_parallel", &WhisperWrapper::tensor_parallel,
"Run model with tensor parallel mode.")
.def_property_readonly("num_active_batches", &WhisperWrapper::num_active_batches,
"Number of batches waiting to be processed or currently processed.")
.def("encode", &WhisperWrapper::encode,
py::arg("features"),
py::arg("to_cpu")=false,
py::call_guard<py::gil_scoped_release>(),
R"pbdoc(
Encodes the input features.
Arguments:
features: Mel spectogram of the audio, as a float array with shape
``[batch_size, n_mels, chunk_length]``.
to_cpu: Copy the encoder output to the CPU before returning the value.
Returns:
The encoder output.
)pbdoc")
// Defaults here mirror the WhisperOptions defaults exposed to Python.
.def("generate", &WhisperWrapper::generate,
py::arg("features"),
py::arg("prompts"),
py::kw_only(),
py::arg("asynchronous")=false,
py::arg("beam_size")=5,
py::arg("patience")=1,
py::arg("num_hypotheses")=1,
py::arg("length_penalty")=1,
py::arg("repetition_penalty")=1,
py::arg("no_repeat_ngram_size")=0,
py::arg("max_length")=448,
py::arg("return_scores")=false,
py::arg("return_logits_vocab")=false,
py::arg("return_no_speech_prob")=false,
py::arg("max_initial_timestamp_index")=50,
py::arg("suppress_blank")=true,
py::arg("suppress_tokens")=std::vector<int>{-1},
py::arg("sampling_topk")=1,
py::arg("sampling_temperature")=1,
py::call_guard<py::gil_scoped_release>(),
R"pbdoc(
Encodes the input features and generates from the given prompt.
Arguments:
features: Mel spectogram of the audio, as a float array with shape
``[batch_size, n_mels, chunk_length]``. This method also accepts the encoded
features returned by the method :meth:`ctranslate2.models.Whisper.encode`,
which have shape ``[batch_size, chunk_length // 2, d_model]``.
prompts: Batch of initial string tokens or token IDs.
asynchronous: Run the model asynchronously.
beam_size: Beam size (1 for greedy search).
patience: Beam search patience factor, as described in
https://arxiv.org/abs/2204.05424. The decoding will continue until
beam_size*patience hypotheses are finished.
num_hypotheses: Number of hypotheses to return.
length_penalty: Exponential penalty applied to the length during beam search.
repetition_penalty: Penalty applied to the score of previously generated tokens
(set > 1 to penalize).
no_repeat_ngram_size: Prevent repetitions of ngrams with this size
(set 0 to disable).
max_length: Maximum generation length.
return_scores: Include the scores in the output.
return_logits_vocab: Include the log probs in the output
return_no_speech_prob: Include the probability of the no speech token in the
result.
max_initial_timestamp_index: Maximum index of the first predicted timestamp.
suppress_blank: Suppress blank outputs at the beginning of the sampling.
suppress_tokens: List of token IDs to suppress. -1 will suppress a default set
of symbols as defined in the model ``config.json`` file.
sampling_topk: Randomly sample predictions from the top K candidates.
sampling_temperature: Sampling temperature to generate more random samples.
Returns:
A list of generation results.
)pbdoc")
.def("detect_language", &WhisperWrapper::detect_language,
py::arg("features"),
py::call_guard<py::gil_scoped_release>(),
R"pbdoc(
Returns the probability of each language.
Arguments:
features: Mel spectogram of the audio, as a float array with shape
``[batch_size, n_mels, chunk_length]``. This method also accepts the encoded
features returned by the method :meth:`ctranslate2.models.Whisper.encode`,
which have shape ``[batch_size, chunk_length // 2, d_model]``.
Returns:
For each batch, a list of pairs (language, probability) ordered from
best to worst probability.
Raises:
RuntimeError: if the model is not multilingual.
)pbdoc")
.def("align", &WhisperWrapper::align,
py::arg("features"),
py::arg("start_sequence"),
py::arg("text_tokens"),
py::arg("num_frames"),
py::kw_only(),
py::arg("median_filter_width")=7,
py::call_guard<py::gil_scoped_release>(),
R"pbdoc(
Computes the alignments between the text tokens and the audio.
Arguments:
features: Mel spectogram of the audio, as a float array with shape
``[batch_size, n_mels, chunk_length]``. This method also accepts the encoded
features returned by the method :meth:`ctranslate2.models.Whisper.encode`,
which have shape ``[batch_size, chunk_length // 2, d_model]``.
start_sequence: The start sequence tokens.
text_tokens: Batch of text tokens to align.
num_frames: Number of non padding frames in the features.
median_filter_width: Width of the median filter kernel.
Returns:
A list of alignment results.
)pbdoc")
// Model lifecycle helpers: unload to free device memory, reload to resume.
.def("unload_model", &WhisperWrapper::unload_model,
py::arg("to_cpu")=false,
py::call_guard<py::gil_scoped_release>(),
R"pbdoc(
Unloads the model attached to this whisper but keep enough runtime context
to quickly resume whisper on the initial device.
Arguments:
to_cpu: If ``True``, the model is moved to the CPU memory and not fully unloaded.
)pbdoc")
.def("load_model", &WhisperWrapper::load_model,
py::arg("keep_cache")=false,
py::call_guard<py::gil_scoped_release>(),
R"pbdoc(
Loads the model back to the initial device.
Arguments:
keep_cache: If ``True``, the model cache in the CPU memory is not deleted if it exists.
)pbdoc")
.def_property_readonly("model_is_loaded", &WhisperWrapper::model_is_loaded,
"Whether the model is loaded on the initial device and ready to be used.")
;
}
}
}