forked from OpenNMT/CTranslate2
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgeneration_result.cc
More file actions
68 lines (57 loc) · 3.22 KB
/
generation_result.cc
File metadata and controls
68 lines (57 loc) · 3.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
#include "module.h"
#include <ctranslate2/generation.h>
#include "utils.h"
namespace ctranslate2 {
namespace python {
void register_generation_result(py::module& m) {
py::class_<GenerationStepResult>(m, "GenerationStepResult",
"The result for a single generation step.")
.def_readonly("step", &GenerationStepResult::step,
"The decoding step.")
.def_readonly("batch_id", &GenerationStepResult::batch_id,
"The batch index.")
.def_readonly("token_id", &GenerationStepResult::token_id,
"ID of the generated token.")
.def_readonly("hypothesis_id", &GenerationStepResult::hypothesis_id,
"Index of the hypothesis in the batch.")
.def_readonly("token", &GenerationStepResult::token,
"String value of the generated token.")
.def_readonly("log_prob", &GenerationStepResult::score,
"Log probability of the token (``None`` if :obj:`return_log_prob` was disabled).")
.def_readonly("logits", &GenerationStepResult::logits,
"Log probability on the vocab of all tokens.")
.def_readonly("is_last", &GenerationStepResult::is_last,
"Whether this step is the last decoding step for this batch.")
.def("__repr__", [](const GenerationStepResult& result) {
return "GenerationStepResult(step=" + std::string(py::repr(py::cast(result.step)))
+ ", batch_id=" + std::string(py::repr(py::cast(result.batch_id)))
+ ", token_id=" + std::string(py::repr(py::cast(result.token_id)))
+ ", hypothesis_id=" + std::string(py::repr(py::cast(result.hypothesis_id)))
+ ", token=" + std::string(py::repr(py::cast(result.token)))
+ ", log_prob=" + std::string(py::repr(py::cast(result.score)))
+ ", logits=" + std::string(py::repr(py::cast(result.logits)))
+ ", is_last=" + std::string(py::repr(py::cast(result.is_last)))
+ ")";
})
;
py::class_<GenerationResult>(m, "GenerationResult", "A generation result.")
.def_readonly("sequences", &GenerationResult::sequences,
"Generated sequences of tokens.")
.def_readonly("sequences_ids", &GenerationResult::sequences_ids,
"Generated sequences of token IDs.")
.def_readonly("scores", &GenerationResult::scores,
"Score of each sequence (empty if :obj:`return_scores` was disabled).")
.def_readonly("logits", &GenerationResult::logits,
"Logits of each sequence (empty if :obj:`return_logits_vocab` was disabled).")
.def("__repr__", [](const GenerationResult& result) {
return "GenerationResult(sequences=" + std::string(py::repr(py::cast(result.sequences)))
+ ", sequences_ids=" + std::string(py::repr(py::cast(result.sequences_ids)))
+ ", scores=" + std::string(py::repr(py::cast(result.scores)))
+ ", logits=" + std::string(py::repr(py::cast(result.logits)))
+ ")";
})
;
declare_async_wrapper<GenerationResult>(m, "AsyncGenerationResult");
}
}
}