11from abc import ABC , abstractmethod
22from typing import Any , AsyncIterator , Callable , Dict , Optional , Union
33
4- from fastapi import APIRouter
4+ from fastapi import APIRouter , Request
55from litellm import ModelResponse
66from litellm .types .llms .openai import ChatCompletionRequest
77
8+ from codegate .codegate_logging import setup_logging
89from codegate .pipeline .base import PipelineResult , SequentialPipelineProcessor
910from codegate .providers .completion .base import BaseCompletionHandler
1011from codegate .providers .formatting .input_pipeline import PipelineResponseFormatter
1112from codegate .providers .normalizer .base import ModelInputNormalizer , ModelOutputNormalizer
1213
14+ logger = setup_logging ()
1315StreamGenerator = Callable [[AsyncIterator [Any ]], AsyncIterator [str ]]
1416
1517
@@ -25,12 +27,14 @@ def __init__(
2527 output_normalizer : ModelOutputNormalizer ,
2628 completion_handler : BaseCompletionHandler ,
2729 pipeline_processor : Optional [SequentialPipelineProcessor ] = None ,
30+ fim_pipeline_processor : Optional [SequentialPipelineProcessor ] = None ,
2831 ):
2932 self .router = APIRouter ()
3033 self ._completion_handler = completion_handler
3134 self ._input_normalizer = input_normalizer
3235 self ._output_normalizer = output_normalizer
3336 self ._pipeline_processor = pipeline_processor
37+ self ._fim_pipelin_processor = fim_pipeline_processor
3438
3539 self ._pipeline_response_formatter = PipelineResponseFormatter (output_normalizer )
3640
@@ -48,22 +52,76 @@ def provider_route_name(self) -> str:
4852 async def _run_input_pipeline (
4953 self ,
5054 normalized_request : ChatCompletionRequest ,
55+ is_fim_request : bool
5156 ) -> PipelineResult :
52- if self ._pipeline_processor is None :
57+ # Decide which pipeline processor to use
58+ if is_fim_request :
59+ pipeline_processor = self ._fim_pipelin_processor
60+ logger .info ('FIM pipeline selected for execution.' )
61+ else :
62+ pipeline_processor = self ._pipeline_processor
63+ logger .info ('Chat completion pipeline selected for execution.' )
64+ if pipeline_processor is None :
5365 return PipelineResult (request = normalized_request )
5466
55- result = await self . _pipeline_processor .process_request (normalized_request )
67+ result = await pipeline_processor .process_request (normalized_request )
5668
5769 # TODO(jakub): handle this by returning a message to the client
5870 if result .error_message :
5971 raise Exception (result .error_message )
6072
6173 return result
6274
75+ def _is_fim_request_url (self , request : Request ) -> bool :
76+ """
77+ Checks the request URL to determine if a request is FIM or chat completion.
78+ Used by: llama.cpp
79+ """
80+ request_path = request .url .path
81+ # Evaluate first a larger substring.
82+ if request_path .endswith ("/chat/completions" ):
83+ return False
84+
85+ if request_path .endswith ("/completions" ):
86+ return True
87+
88+ return False
89+
90+ def _is_fim_request_body (self , data : Dict ) -> bool :
91+ """
92+ Determine from the raw incoming data if it's a FIM request.
93+ Used by: OpenAI and Anthropic
94+ """
95+ messages = data .get ('messages' , [])
96+ if not messages :
97+ return False
98+
99+ first_message_content = messages [0 ].get ('content' )
100+ if first_message_content is None :
101+ return False
102+
103+ fim_stop_sequences = ['</COMPLETION>' , '<COMPLETION>' , '</QUERY>' , '<QUERY>' ]
104+ if isinstance (first_message_content , str ):
105+ msg_prompt = first_message_content
106+ elif isinstance (first_message_content , list ):
107+ msg_prompt = first_message_content [0 ].get ('text' , '' )
108+ else :
109+ logger .warning (f'Could not determine if message was FIM from data: { data } ' )
110+ return False
111+ return all ([stop_sequence in msg_prompt for stop_sequence in fim_stop_sequences ])
112+
113+ def _is_fim_request (self , request : Request , data : Dict ) -> bool :
114+ """
115+ Determin if the request is FIM by the URL or the data of the request.
116+ """
117+ # Avoid more expensive inspection of body by just checking the URL.
118+ if self ._is_fim_request_url (request ):
119+ return True
120+
121+ return self ._is_fim_request_body (data )
122+
63123 async def complete (
64- self ,
65- data : Dict ,
66- api_key : Optional [str ],
124+ self , data : Dict , api_key : Optional [str ], is_fim_request : bool
67125 ) -> Union [ModelResponse , AsyncIterator [ModelResponse ]]:
68126 """
69127 Main completion flow with pipeline integration
@@ -79,7 +137,7 @@ async def complete(
79137 normalized_request = self ._input_normalizer .normalize (data )
80138 streaming = data .get ("stream" , False )
81139
82- input_pipeline_result = await self ._run_input_pipeline (normalized_request )
140+ input_pipeline_result = await self ._run_input_pipeline (normalized_request , is_fim_request )
83141 if input_pipeline_result .response :
84142 return self ._pipeline_response_formatter .handle_pipeline_response (
85143 input_pipeline_result .response , streaming
0 commit comments