@@ -2936,6 +2936,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
29362936 /* .language =*/ " en" ,
29372937
29382938 /* .suppress_blank =*/ true ,
2939+ /* .suppress_non_speech_tokens =*/ true ,
29392940
29402941 /* .temperature =*/ 0 .0f ,
29412942 /* .max_initial_ts =*/ 1 .0f ,
@@ -3077,6 +3078,14 @@ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len, bool
30773078 return res;
30783079}
30793080
3081+ static const std::vector<std::string> non_speech_tokens
3082+ {
3083+ " \" " , " #" , " (" , " )" , " *" , " +" , " /" , " :" , " ;" , " <" , " =" , " >" , " @" , " [" , " \\ " , " ]" , " ^" ,
3084+ " _" , " `" , " {" , " |" , " }" , " ~" , " 「" , " 」" , " 『" , " 』" , " <<" , " >>" , " <<<" , " >>>" , " --" ,
3085+ " ---" , " -(" , " -[" , " ('" , " (\" " , " ((" , " ))" , " (((" , " )))" , " [[" , " ]]" , " {{" , " }}" , " ♪♪" ,
3086+ " ♪♪♪" ," ♩" , " ♪" , " ♫" , " ♬" , " ♭" , " ♮" , " ♯"
3087+ };
3088+
30803089// process the logits for the selected decoder
30813090// - applies logit filters
30823091// - computes logprobs and probs
@@ -3137,6 +3146,33 @@ static void whisper_process_logits(
31373146 logits[vocab.token_translate ] = -INFINITY;
31383147 logits[vocab.token_transcribe ] = -INFINITY;
31393148
3149+
3150+ // suppress non-speech tokens
3151+ // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
3152+ if (params.suppress_non_speech_tokens )
3153+ {
3154+ for (const std::string &token : non_speech_tokens)
3155+ {
3156+ std::string suppress_tokens[] = {token, " " + token};
3157+ for (const std::string &suppress_token : suppress_tokens)
3158+ {
3159+ if (vocab.token_to_id .find (suppress_token) != vocab.token_to_id .end ())
3160+ {
3161+ logits[vocab.token_to_id .at (suppress_token)] = -INFINITY;
3162+ }
3163+ }
3164+ }
3165+ // allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
3166+ if (vocab.token_to_id .find (" -" ) != vocab.token_to_id .end ())
3167+ {
3168+ logits[vocab.token_to_id .at (" -" )] = -INFINITY;
3169+ }
3170+ if (vocab.token_to_id .find (" '" ) != vocab.token_to_id .end ())
3171+ {
3172+ logits[vocab.token_to_id .at (" '" )] = -INFINITY;
3173+ }
3174+ }
3175+
31403176 // timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
31413177 // https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L414-L424
31423178 {
0 commit comments