@@ -275,41 +275,57 @@ std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::stri
275275 return tokens;
276276}
277277
278+ // TODO: Calculate this constant from the vocabulary
279+ #define MAX_TOKEN_LEN 18
280+ // SentencePiece implementation after https://guillaume-be.github.io/2020-05-30/sentence_piece
278281std::vector<gpt_vocab::id> llama_tokenize (const gpt_vocab & vocab, const std::string & text, bool bos) {
279- // auto res = gpt_tokenize(vocab, text);
280-
281- // if (bos) {
282- // res.insert(res.begin(), 1); // TODO: replace with vocab.bos
283- // }
284-
285282 std::vector<gpt_vocab::id> res;
286-
287- if (bos) {
288- res.push_back (1 ); // TODO: replace with vocab.bos
289- }
290-
291- // find the longest token that matches the text
292- int pos = 0 ;
293- while (true ) {
294- int l = 0 ;
295- int t = 0 ;
296- for (const auto & kv : vocab.id_to_token ) {
297- if (kv.second .size () < l) continue ;
298- if (kv.second .size () > text.size () - pos) continue ;
299- if (text.substr (pos, kv.second .size ()) == kv.second ) {
300- l = kv.second .size ();
301- t = kv.first ;
283+ std::vector<int > score;
284+ std::vector<gpt_vocab::id> prev;
285+ int len = text.length ();
286+
287+ score.resize (len + 1 );
288+ prev.resize (len + 1 );
289+
290+ // Forward pass
291+ for (int i = 0 ; i < len; i++) {
292+ int max_len = std::min (len - i, MAX_TOKEN_LEN);
293+ for (int sub_len = 1 ; sub_len <= len - i; sub_len++) {
294+ auto sub = text.substr (i, sub_len);
295+ auto token = vocab.token_to_id .find (sub);
296+ if (token != vocab.token_to_id .end ()) {
297+ int token_score = sub.length () * sub.length ();
298+ int local_score = score[i] + token_score;
299+ int next = i + sub_len;
300+ if (score[next] < local_score) {
301+ score[next] = local_score;
302+ prev[next] = (*token).second ;
303+ }
302304 }
303305 }
306+ }
304307
305- if (l == 0 ) {
306- break ;
308+ // Backward pass
309+ int i = len;
310+ while (i > 0 ) {
311+ gpt_vocab::id token_id = prev[i];
312+ if (token_id == 0 ) {
313+ // TODO: Return error or something more meaningful
314+ printf (" failed to tokenize string!\n " );
315+ break ;
307316 }
317+ res.push_back (token_id);
318+ auto token = (*vocab.id_to_token .find (token_id)).second ;
319+ i -= token.length ();
320+ }
308321
309- res. push_back (t);
310- pos += l;
322+ if (bos) {
323+ res. push_back ( 1 ); // TODO: replace with vocab.bos
311324 }
312325
326+ // Pieces are in reverse order so correct that
327+ std::reverse (res.begin (), res.end ());
328+
313329 return res;
314330}
315331
0 commit comments