forked from OpenNMT/CTranslate2
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathwav2vec2.cc
More file actions
125 lines (105 loc) · 5.89 KB
/
wav2vec2.cc
File metadata and controls
125 lines (105 loc) · 5.89 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
#include "module.h"
#include <ctranslate2/models/wav2vec2.h>
#include "replica_pool.h"
namespace ctranslate2 {
namespace python {
// Python-facing wrapper around a pool of models::Wav2Vec2 replicas.
// Construction is inherited unchanged from ReplicaPoolHelper; this class only
// adds the `encode` entry point that the bindings expose.
class Wav2Vec2Wrapper : public ReplicaPoolHelper<models::Wav2Vec2> {
public:
  using ReplicaPoolHelper::ReplicaPoolHelper;

  // Runs the encoder on `features` through the replica pool and returns the
  // resulting StorageView.
  //
  // features: input forwarded as-is to the pool.
  // to_cpu:   forwarded as-is to the pool's encode call.
  StorageView encode(const StorageView& features, const bool to_cpu) {
    // Shared lock on the helper's mutex for the whole call.
    // NOTE(review): presumably this keeps load/unload from running
    // concurrently with encoding - confirm in ReplicaPoolHelper.
    std::shared_lock lock(_mutex);
    assert_model_is_ready();

    // NOTE(review): the pool appears to return a future-like handle; `.get()`
    // is what yields the actual result.
    auto pending_output = _pool->encode(features, to_cpu);
    return pending_output.get();
  }
};
// Registers the Python class "Wav2Vec2" on module `m`, backed by
// Wav2Vec2Wrapper: its constructor, read-only properties, and the
// encode / unload_model / load_model methods.
//
// Registrations are issued in a fixed order, one statement each. The raw-string
// docstrings are intentionally left unindented: their bytes are what Python
// users see as the docstring text.
void register_wav2vec2(py::module& m) {
  using Wrapper = Wav2Vec2Wrapper;

  py::class_<Wrapper> binding(
      m, "Wav2Vec2",
      R"pbdoc(
Implements the Wav2Vec2 speech recognition model published by Facebook.
See Also:
https://github.com/facebookresearch/fairseq/tree/main/examples/wav2vec
)pbdoc");

  // Constructor: model_path and device may be positional, everything after
  // py::kw_only() is keyword-only.
  binding.def(
      py::init<const std::string&,
               const std::string&,
               const std::variant<int, std::vector<int>>&,
               const StringOrMap&,
               size_t,
               size_t,
               long,
               bool,
               bool,
               py::object>(),
      py::arg("model_path"),
      py::arg("device")="cpu",
      py::kw_only(),
      py::arg("device_index")=0,
      py::arg("compute_type")="default",
      py::arg("inter_threads")=1,
      py::arg("intra_threads")=0,
      py::arg("max_queued_batches")=0,
      py::arg("flash_attention")=false,
      py::arg("tensor_parallel")=false,
      py::arg("files")=py::none(),
      R"pbdoc(
Initializes a Wav2Vec2 model from a converted model.
Arguments:
model_path: Path to the CTranslate2 model directory.
device: Device to use (possible values are: cpu, cuda, auto).
device_index: Device IDs where to place this model on.
compute_type: Model computation type or a dictionary mapping a device name
to the computation type (possible values are: default, auto, int8, int8_float32,
int8_float16, int8_bfloat16, int16, float16, bfloat16, float32).
inter_threads: Number of workers to allow executing multiple batches in parallel.
intra_threads: Number of OpenMP threads per worker (0 to use a default value).
max_queued_batches: Maximum numbers of batches in the worker queue (-1 for unlimited,
0 for an automatic value). When the queue is full, future requests will block
until a free slot is available.
flash_attention: run model with flash attention 2 for self-attention layer
tensor_parallel: run model with tensor parallel mode
files: Load model files from the memory. This argument is a dictionary mapping
file names to file contents as file-like or bytes objects. If this is set,
:obj:`model_path` acts as an identifier for this model.
)pbdoc");

  // Read-only introspection properties.
  binding.def_property_readonly(
      "device", &Wrapper::device,
      "Device this model is running on.");
  binding.def_property_readonly(
      "device_index", &Wrapper::device_index,
      "List of device IDs where this model is running on.");
  binding.def_property_readonly(
      "compute_type", &Wrapper::compute_type,
      "Computation type used by the model.");
  // Exposed to Python as "num_workers" but backed by num_replicas.
  binding.def_property_readonly(
      "num_workers", &Wrapper::num_replicas,
      "Number of model workers backing this instance.");
  binding.def_property_readonly(
      "num_queued_batches", &Wrapper::num_queued_batches,
      "Number of batches waiting to be processed.");
  binding.def_property_readonly(
      "tensor_parallel", &Wrapper::tensor_parallel,
      "Run model with tensor parallel mode.");
  binding.def_property_readonly(
      "num_active_batches", &Wrapper::num_active_batches,
      "Number of batches waiting to be processed or currently processed.");

  // Methods below release the GIL for the duration of the C++ call.
  binding.def(
      "encode", &Wrapper::encode,
      py::arg("features"),
      py::arg("to_cpu")=false,
      py::call_guard<py::gil_scoped_release>(),
      R"pbdoc(
Encodes the input features.
Arguments:
features: hidden_states (up to v.4.3.1, https://github.com/OpenNMT/CTranslate2/blob/59c7dda738892df7a064aa360d0e45a4c3840b07/python/tests/test_transformers.py#L1028) or
raw audio, as a float array with shape (followed by VAD)
``[batch_size, 409, 1024]`` or ``[batch_size, 1, 131200]``
to_cpu: Copy the encoder output to the CPU before returning the value.
Returns:
The encoder output.
)pbdoc");

  binding.def(
      "unload_model", &Wrapper::unload_model,
      py::arg("to_cpu")=false,
      py::call_guard<py::gil_scoped_release>(),
      R"pbdoc(
Unloads the model attached to this wav2vec2 but keep enough runtime context
to quickly resume wav2vec2 on the initial device.
Arguments:
to_cpu: If ``True``, the model is moved to the CPU memory and not fully unloaded.
)pbdoc");

  binding.def(
      "load_model", &Wrapper::load_model,
      py::arg("keep_cache")=false,
      py::call_guard<py::gil_scoped_release>(),
      R"pbdoc(
Loads the model back to the initial device.
Arguments:
keep_cache: If ``True``, the model cache in the CPU memory is not deleted if it exists.
)pbdoc");

  binding.def_property_readonly(
      "model_is_loaded", &Wrapper::model_is_loaded,
      "Whether the model is loaded on the initial device and ready to be used.");
}
}
}