Skip to content

Commit 2d06d25

Browse files
authored
Add method generate_tokens to return a generator of tokens (OpenNMT#1165)
1 parent 2482153 commit 2d06d25

15 files changed

Lines changed: 360 additions & 6 deletions

docs/conf.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,12 @@ def _remove_self(signature):
5050
arguments.pop(0)
5151
return "(%s)" % ", ".join(arguments)
5252

53+
def _remove_private(signature):
54+
index = signature.find(", _")
55+
if index > 0:
56+
signature = signature[:index] + ")"
57+
return signature
58+
5359
def _reformat_typehints(content):
5460
return content.replace(
5561
"ctranslate2._ext.",
@@ -58,6 +64,7 @@ def _reformat_typehints(content):
5864

5965
if signature is not None:
6066
signature = _remove_self(signature)
67+
signature = _remove_private(signature)
6168
signature = _reformat_typehints(signature)
6269

6370
if return_annotation is not None:

include/ctranslate2/decoding.h

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
#pragma once
22

3+
#include <functional>
4+
#include <optional>
5+
36
#include "ctranslate2/decoding_utils.h"
47
#include "ctranslate2/devices.h"
58
#include "ctranslate2/layers/decoder.h"
@@ -14,6 +17,14 @@ namespace ctranslate2 {
1417
std::vector<std::vector<std::vector<float>>> attention;
1518
};
1619

20+
struct DecodingStepResult {
21+
size_t step;
22+
size_t batch_id;
23+
size_t token_id;
24+
std::optional<float> log_prob;
25+
bool is_last = false;
26+
};
27+
1728

1829
class SearchStrategy {
1930
public:
@@ -91,7 +102,9 @@ namespace ctranslate2 {
91102
class GreedySearch : public SearchStrategy {
92103
public:
93104
// Penalties are only applied to return scores consistent with the beam search.
94-
GreedySearch(const float length_penalty = 0, const float coverage_penalty = 0);
105+
GreedySearch(const float length_penalty = 0,
106+
const float coverage_penalty = 0,
107+
std::function<void(DecodingStepResult)> callback = nullptr);
95108

96109
std::vector<DecodingResult>
97110
search(layers::Decoder& decoder,
@@ -113,6 +126,7 @@ namespace ctranslate2 {
113126
private:
114127
const float _length_penalty;
115128
const float _coverage_penalty;
129+
const std::function<void(DecodingStepResult)> _callback;
116130
};
117131

118132

@@ -140,6 +154,7 @@ namespace ctranslate2 {
140154
std::vector<size_t> disable_ids_begin;
141155
std::vector<std::vector<size_t>> disable_sequences;
142156
std::vector<std::shared_ptr<LogitsProcessor>> logits_processors;
157+
std::function<void(DecodingStepResult)> callback = nullptr;
143158
};
144159

145160
std::vector<DecodingResult>

include/ctranslate2/generation.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,13 @@
33
#include <vector>
44
#include <string>
55

6+
#include "decoding.h"
7+
#include "vocabulary.h"
8+
69
namespace ctranslate2 {
710

11+
struct GenerationStepResult;
12+
813
struct GenerationOptions {
914
// Beam size to use for beam search (set 1 to run greedy search).
1015
size_t beam_size = 1;
@@ -51,6 +56,9 @@ namespace ctranslate2 {
5156

5257
// Include the input tokens in the generation result.
5358
bool include_prompt_in_result = true;
59+
60+
// Function to call for each generated token in greedy search.
61+
std::function<void(GenerationStepResult)> callback = nullptr;
5462
};
5563

5664
struct GenerationResult {
@@ -67,4 +75,24 @@ namespace ctranslate2 {
6775
}
6876
};
6977

78+
struct GenerationStepResult {
79+
size_t step;
80+
size_t batch_id;
81+
size_t token_id;
82+
std::string token;
83+
std::optional<float> log_prob;
84+
bool is_last;
85+
86+
GenerationStepResult() = default;
87+
GenerationStepResult(const DecodingStepResult& result, const Vocabulary& vocabulary)
88+
: step(result.step)
89+
, batch_id(result.batch_id)
90+
, token_id(result.token_id)
91+
, token(vocabulary.to_token(result.token_id))
92+
, log_prob(result.log_prob)
93+
, is_last(result.is_last)
94+
{
95+
}
96+
};
97+
7098
}

include/ctranslate2/translation.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
#include <string>
55
#include <vector>
66

7+
#include "generation.h"
8+
79
namespace ctranslate2 {
810

911
struct TranslationOptions {
@@ -70,6 +72,9 @@ namespace ctranslate2 {
7072

7173
// Replace unknown target tokens by the original source token with the highest attention.
7274
bool replace_unknowns = false;
75+
76+
// Function to call for each generated token in greedy search.
77+
std::function<void(GenerationStepResult)> callback = nullptr;
7378
};
7479

7580
struct TranslationResult {

python/cpp/generation_result.cc

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,33 @@ namespace ctranslate2 {
88
namespace python {
99

1010
void register_generation_result(py::module& m) {
11+
py::class_<GenerationStepResult>(m, "GenerationStepResult",
12+
"The result for a single generation step.")
13+
14+
.def_readonly("step", &GenerationStepResult::step,
15+
"The decoding step.")
16+
.def_readonly("batch_id", &GenerationStepResult::batch_id,
17+
"The batch index.")
18+
.def_readonly("token_id", &GenerationStepResult::token_id,
19+
"ID of the generated token.")
20+
.def_readonly("token", &GenerationStepResult::token,
21+
"String value of the generated token.")
22+
.def_readonly("log_prob", &GenerationStepResult::log_prob,
23+
"Log probability of the token (``None`` if :obj:`return_log_prob` was disabled).")
24+
.def_readonly("is_last", &GenerationStepResult::is_last,
25+
"Whether this step is the last decoding step for this batch.")
26+
27+
.def("__repr__", [](const GenerationStepResult& result) {
28+
return "GenerationStepResult(step=" + std::string(py::repr(py::cast(result.step)))
29+
+ ", batch_id=" + std::string(py::repr(py::cast(result.batch_id)))
30+
+ ", token_id=" + std::string(py::repr(py::cast(result.token_id)))
31+
+ ", token=" + std::string(py::repr(py::cast(result.token)))
32+
+ ", log_prob=" + std::string(py::repr(py::cast(result.log_prob)))
33+
+ ", is_last=" + std::string(py::repr(py::cast(result.is_last)))
34+
+ ")";
35+
})
36+
;
37+
1138
py::class_<GenerationResult>(m, "GenerationResult", "A generation result.")
1239

1340
.def_readonly("sequences", &GenerationResult::sequences,

python/cpp/generator.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ namespace ctranslate2 {
3333
bool return_alternatives,
3434
float min_alternative_expansion_prob,
3535
size_t sampling_topk,
36-
float sampling_temperature) {
36+
float sampling_temperature,
37+
std::function<void(GenerationStepResult)> callback) {
3738
if (tokens.empty())
3839
return {};
3940

@@ -54,6 +55,7 @@ namespace ctranslate2 {
5455
options.return_alternatives = return_alternatives;
5556
options.include_prompt_in_result = include_prompt_in_result;
5657
options.min_alternative_expansion_prob = min_alternative_expansion_prob;
58+
options.callback = std::move(callback);
5759
if (suppress_sequences)
5860
options.suppress_sequences = suppress_sequences.value();
5961
if (end_token)
@@ -181,6 +183,7 @@ namespace ctranslate2 {
181183
py::arg("min_alternative_expansion_prob")=0,
182184
py::arg("sampling_topk")=1,
183185
py::arg("sampling_temperature")=1,
186+
py::arg("_callback")=nullptr,
184187
py::call_guard<py::gil_scoped_release>(),
185188
R"pbdoc(
186189
Generates from a batch of start tokens.

python/cpp/replica_pool.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,11 @@ namespace ctranslate2 {
5858
_pool = std::make_unique<T>(_model_loader, _pool_config);
5959
}
6060

61+
~ReplicaPoolHelper() {
62+
pybind11::gil_scoped_release nogil;
63+
_pool.reset();
64+
}
65+
6166
std::string device() const {
6267
return device_to_str(_model_loader.device);
6368
}

python/cpp/translator.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,8 @@ namespace ctranslate2 {
170170
float min_alternative_expansion_prob,
171171
size_t sampling_topk,
172172
float sampling_temperature,
173-
bool replace_unknowns) {
173+
bool replace_unknowns,
174+
std::function<void(GenerationStepResult)> callback) {
174175
if (source.empty())
175176
return {};
176177

@@ -196,6 +197,7 @@ namespace ctranslate2 {
196197
options.return_alternatives = return_alternatives;
197198
options.min_alternative_expansion_prob = min_alternative_expansion_prob;
198199
options.replace_unknowns = replace_unknowns;
200+
options.callback = std::move(callback);
199201
if (suppress_sequences)
200202
options.suppress_sequences = suppress_sequences.value();
201203
if (end_token)
@@ -437,6 +439,7 @@ namespace ctranslate2 {
437439
py::arg("sampling_topk")=1,
438440
py::arg("sampling_temperature")=1,
439441
py::arg("replace_unknowns")=false,
442+
py::arg("_callback")=nullptr,
440443
py::call_guard<py::gil_scoped_release>(),
441444
R"pbdoc(
442445
Translates a batch of tokens.

python/ctranslate2/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
AsyncTranslationResult,
2525
ExecutionStats,
2626
GenerationResult,
27+
GenerationStepResult,
2728
Generator,
2829
ScoringResult,
2930
StorageView,

0 commit comments

Comments
 (0)