forked from ggml-org/llama.cpp
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest-reasoning-budget.cpp
More file actions
237 lines (203 loc) · 10.7 KB
/
test-reasoning-budget.cpp
File metadata and controls
237 lines (203 loc) · 10.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
#include "reasoning-budget.h"
#include "unicode.h"
#include "llama.h"
#include "ggml.h"
#ifdef NDEBUG
#undef NDEBUG
#endif
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <string>
#include <vector>
// Reasoning budget sampler test helper
// These tests use nullptr vocab which safely falls back to treating all tokens as complete
// (The UTF-8 boundary detection logic is tested separately in test_utf8_boundary_detection)
static void test_reasoning_budget(
const char * test_name,
const std::vector<llama_token> & sequence,
const std::vector<llama_token> & start_tokens,
const std::vector<llama_token> & end_tokens,
const std::vector<llama_token> & forced_tokens,
int32_t budget,
common_reasoning_budget_state initial_state,
size_t expected_force_start, // token index where forcing should start (SIZE_MAX = never)
size_t expected_force_end // token index where forcing should end (after this, no more forcing)
) {
// Find the maximum token ID to ensure our vocab covers all tokens
llama_token max_token = 0;
for (auto t : sequence) max_token = std::max(max_token, t);
for (auto t : start_tokens) max_token = std::max(max_token, t);
for (auto t : end_tokens) max_token = std::max(max_token, t);
for (auto t : forced_tokens) max_token = std::max(max_token, t);
// Create a minimal sampler with mock vocabulary
// For this test, we use nullptr as vocab since we're testing state transitions
// The UTF-8 boundary check will treat all tokens as complete (safe fallback)
auto * sampler = common_reasoning_budget_init(
nullptr, // vocab - not used for basic state machine tests
start_tokens,
end_tokens,
forced_tokens,
budget,
initial_state
);
// Create a test token data array for checking forcing behavior
// Vocab size must be large enough to include all tokens (start, end, forced, sequence)
std::vector<llama_token_data> cur;
const size_t n_vocab = (size_t)max_token + 1;
for (size_t i = 0; i < n_vocab; i++) {
cur.emplace_back(llama_token_data{(llama_token)i, logf((float)(i+1)), 0.0f});
}
llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
size_t actual_force_start = SIZE_MAX;
size_t actual_force_end = SIZE_MAX;
// Feed the sequence and track when forcing occurs
for (size_t i = 0; i < sequence.size(); i++) {
// Check if we're in forcing state by applying and seeing if logits are modified
cur_p.selected = -1;
for (size_t j = 0; j < cur.size(); j++) {
cur[j].logit = logf((float)(j+1)); // reset logits
}
llama_sampler_apply(sampler, &cur_p);
// Check if forcing is active (all logits except one should be -INFINITY)
size_t finite_count = 0;
llama_token finite_token = -1;
for (size_t j = 0; j < cur.size(); j++) {
if (std::isfinite(cur[j].logit)) {
finite_count++;
finite_token = cur[j].id;
}
}
llama_sampler_accept(sampler, sequence[i]);
fprintf(stderr, " i=%zu: token=%d, finite_count=%zu, finite_token=%d\n", i, (int)sequence[i], finite_count, (int)finite_token);
if (finite_count == 1) {
if (actual_force_start == SIZE_MAX) {
actual_force_start = i;
}
actual_force_end = i;
} else if (actual_force_start != SIZE_MAX && actual_force_end != SIZE_MAX) {
// Forcing stopped
break;
}
}
llama_sampler_free(sampler);
// Verify forcing occurred at expected positions
if (expected_force_start == SIZE_MAX) {
if (actual_force_start != SIZE_MAX) {
fprintf(stderr, "Test '%s' FAILED: Expected no forcing, but forcing occurred at %zu\n", test_name, actual_force_start);
GGML_ASSERT(false && "Expected no forcing, but forcing occurred");
}
} else {
if (actual_force_start == SIZE_MAX) {
fprintf(stderr, "Test '%s' FAILED: Expected forcing but none occurred\n", test_name);
GGML_ASSERT(false && "Expected forcing but none occurred");
}
if (actual_force_start != expected_force_start) {
fprintf(stderr, "Test '%s' FAILED: Forcing started at %zu, expected %zu\n", test_name, actual_force_start, expected_force_start);
GGML_ASSERT(false && "Forcing started at wrong position");
}
}
if (expected_force_end != SIZE_MAX) {
if (actual_force_end < expected_force_end) {
fprintf(stderr, "Test '%s' FAILED: Forcing ended at %zu, expected >= %zu\n", test_name, actual_force_end, expected_force_end);
GGML_ASSERT(false && "Forcing ended too early");
}
}
fprintf(stderr, " Test '%s' passed (force_start=%zu, force_end=%zu)\n", test_name, actual_force_start, actual_force_end);
(void)sequence;
}
// UTF-8 boundary detection unit test
// Exercises common_utf8_is_complete() (declared in reasoning-budget.h) with byte strings that
// end on a full code point and with byte strings that are cut off in the middle of one.
// Each case stays on its own GGML_ASSERT line so a failure points at the exact input.
static void test_utf8_boundary_detection() {
    // Thin predicates so both halves of the test read the same way
    const auto ends_complete  = [](const std::string & bytes) { return  common_utf8_is_complete(bytes); };
    const auto ends_truncated = [](const std::string & bytes) { return !common_utf8_is_complete(bytes); };

    // Strings that must be reported as complete
    GGML_ASSERT(ends_complete("hello"));
    GGML_ASSERT(ends_complete(""));
    GGML_ASSERT(ends_complete("\xC2\xA0"));         // full 2-byte sequence (U+00A0)
    GGML_ASSERT(ends_complete("\xE2\x80\x9C"));     // full 3-byte sequence (left double quote)
    GGML_ASSERT(ends_complete("\xF0\x9F\x98\x80")); // full 4-byte sequence (emoji)
    GGML_ASSERT(ends_complete("abc\xC3\xA9"));      // ASCII prefix + full 2-byte sequence

    // Strings that must be reported as incomplete
    GGML_ASSERT(ends_truncated("\xC2"));            // 2-byte lead, continuation missing
    GGML_ASSERT(ends_truncated("\xE2\x80"));        // 3-byte lead + 1 continuation, 1 missing
    GGML_ASSERT(ends_truncated("\xE2"));            // 3-byte lead, 2 missing
    GGML_ASSERT(ends_truncated("\xF0\x9F\x98"));    // 4-byte lead + 2 continuations, 1 missing
    GGML_ASSERT(ends_truncated("\xF0\x9F"));        // 4-byte lead + 1 continuation, 2 missing
    GGML_ASSERT(ends_truncated("\xF0"));            // 4-byte lead, 3 missing
    GGML_ASSERT(ends_truncated("\x80"));            // continuation byte with no lead

    // ASCII followed by the beginning (or the whole) of a multi-byte sequence
    GGML_ASSERT(ends_truncated("hello\xC3"));       // ASCII + truncated 2-byte sequence
    GGML_ASSERT(ends_complete("hello\xC3\xA9"));    // ASCII + full 2-byte sequence
}
// Test driver: runs the five reasoning-budget state-machine scenarios, then the UTF-8
// boundary detection checks. Progress goes to stdout; per-step traces and failure
// details are printed by the helpers on stderr.
int main(void) {
    // Reasoning budget sampler tests
    printf("Testing reasoning budget sampler... ");

    // Test 1: Basic budget with start/end tokens - no forcing (natural end before budget exhausted)
    test_reasoning_budget("natural end before budget exhausted",
        /* sequence */ {100, 50, 51, 101, 52}, // start, two tokens, end, one more
        /* start    */ {100},
        /* end      */ {101},
        /* forced   */ {102},                  // never emitted in this scenario
        /* budget   */ 5,
        REASONING_BUDGET_IDLE,
        SIZE_MAX, SIZE_MAX);                   // no forcing expected (natural end)

    // Test 2: Budget exhausted, forcing should occur
    // Flow: i=0 apply()->passthrough, accept(100)->COUNTING; i=1 accept(50)->remaining=1
    //       i=2 accept(51)->remaining=0->FORCING; i=3 apply() forces token[0]; i=4 apply() forces token[1]
    //       At i=4, accept() advances force_pos to 2 which equals forced_tokens.size(), so state becomes DONE
    test_reasoning_budget("budget exhausted forcing",
        /* sequence */ {100, 50, 51, 52, 53},  // start + 4 tokens (budget=2)
        /* start    */ {100},
        /* end      */ {101},
        /* forced   */ {102, 101},             // forced message + end
        /* budget   */ 2,
        REASONING_BUDGET_IDLE,
        3,                                     // forcing starts at i=3 (accept at i=2 depletes budget, apply at i=3 forces)
        4);                                    // forcing continues through i=4 (accept at i=4 transitions to DONE)

    // Test 3: Activate immediately with budget=0, forcing should start right away
    // Flow: init promotes COUNTING+budget=0 to FORCING, so apply() sees FORCING at i=0
    test_reasoning_budget("activate immediately budget=0",
        /* sequence */ {100, 50, 51, 52},      // start token first, then 3 tokens
        /* start    */ {100},
        /* end      */ {101},
        /* forced   */ {102, 101},
        /* budget   */ 0,
        REASONING_BUDGET_COUNTING,             // starts counting, promoted to FORCING since budget=0
        0,                                     // forcing starts at i=0 (initialized in FORCING, apply forces immediately)
        1);                                    // forcing continues through i=1 (accept at i=1 transitions to DONE)

    // Test 4: No start/end tokens configured - passthrough (no forcing)
    test_reasoning_budget("no start/end configured",
        /* sequence */ {50, 51, 52, 53},
        /* start    */ {},
        /* end      */ {},
        /* forced   */ {102},
        /* budget   */ 2,
        REASONING_BUDGET_IDLE,
        SIZE_MAX, SIZE_MAX);                   // no forcing (no start/end configured)

    // Test 5: Activate immediately with budget > 0, count down then force
    // Flow: i=0 accept(50)->remaining=1, i=1 accept(51)->remaining=0->FORCING
    //       Forcing starts at i=2 (apply sees FORCING after accept at i=1 transitioned)
    test_reasoning_budget("activate immediately with budget",
        /* sequence */ {50, 51, 52, 53},
        /* start    */ {100},
        /* end      */ {101},
        /* forced   */ {102, 101},
        /* budget   */ 2,
        REASONING_BUDGET_COUNTING,
        2,                                     // forcing starts at i=2 (after 2 accepts deplete budget, apply at i=2 forces)
        3);                                    // forcing continues through i=3

    printf("OK (5 tests passed)\n");

    printf("Testing UTF-8 boundary detection... ");
    test_utf8_boundary_detection();
    printf("OK\n");

    return 0;
}