99from ctypes import CDLL , byref , c_char_p , c_int
1010from typing import Any , Callable , List , Dict , Optional
1111
12+ from runpod .version import __version__ as runpod_version
1213from runpod .serverless .modules .rp_logger import RunPodLogger
13-
14+ from runpod . serverless . modules import rp_job
1415
1516log = RunPodLogger ()
1617
@@ -44,6 +45,7 @@ class Hook: # pylint: disable=too-many-instance-attributes
4445
4546 def __new__ (cls ):
4647 if Hook ._instance is None :
48+ log .debug ("SLS Core | Initializing Hook." )
4749 Hook ._instance = object .__new__ (cls )
4850 Hook ._initialized = False
4951 return Hook ._instance
@@ -136,7 +138,7 @@ def progress_update(self, job_id: str, json_data: bytes) -> bool:
136138 c_char_p (json_data ), c_int (len (json_data ))
137139 ))
138140
139- def stream_output (self , job_id : str , job_output : bytes ) -> bool :
141+ async def stream_output (self , job_id : str , job_output : bytes ) -> bool :
140142 """
141143 send part of a streaming result to AI-API.
142144 """
@@ -170,48 +172,70 @@ def finish_stream(self, job_id: str) -> bool:
170172
171173
172174# -------------------------------- Process Job ------------------------------- #
# -------------------------------- Process Job ------------------------------- #
async def _process_job(config: Dict[str, Any], job: Dict[str, Any], hook) -> Dict[str, Any]:
    """Process a single job with the configured handler.

    Generator (sync or async) handlers are streamed part-by-part through the
    hook; standard handlers are run once via rp_job.run_job. Whatever happens,
    the final result — or an {'error': ...} dict if the handler raised — is
    posted back through the hook in the ``finally`` block.

    Args:
        config: Worker configuration. Must contain 'handler'; may contain
            'return_aggregate_stream' (bool, default False) to also collect
            streamed parts into the final result.
        job: The job dict. Must contain 'id'; the rest is handler input.
        hook: The Hook singleton used to stream/post output to AI-API.
            # NOTE(review): left unannotated to avoid a hard dependency here;
            # expected to be an instance of Hook — confirm with callers.

    Returns:
        The handler output (aggregated stream for generators), or an
        {'error': str} dict when the handler raised. The original body never
        returned despite its annotation; the explicit return fixes that while
        staying compatible with fire-and-forget callers.
    """
    handler = config['handler']

    result = {}
    try:
        if inspect.isgeneratorfunction(handler) or inspect.isasyncgenfunction(handler):
            log.debug("SLS Core | Running job as a generator.")
            generator_output = rp_job.run_job_generator(handler, job)
            aggregated_output = {'output': []}

            async for part in generator_output:
                log.debug(f"SLS Core | Streaming output: {part}", job['id'])

                # An error part replaces the aggregate and stops the stream.
                if 'error' in part:
                    aggregated_output = part
                    break
                if config.get('return_aggregate_stream', False):
                    aggregated_output['output'].append(part['output'])

                await hook.stream_output(job['id'], part)

            log.debug("SLS Core | Finished streaming output.", job['id'])
            hook.finish_stream(job['id'])
            result = aggregated_output

        else:
            log.debug("SLS Core | Running job as a standard function.")
            result = await rp_job.run_job(handler, job)
            # Unwrap the 'output' envelope when present; otherwise pass through.
            result = result.get('output', result)

    except Exception as err:  # pylint: disable=broad-except
        log.error(f"SLS Core | Error running job: {err}", job['id'])
        result = {'error': str(err)}

    finally:
        # Always post something back, even on failure.
        log.debug(f"SLS Core | Posting output: {result}", job['id'])
        hook.post_output(job['id'], result)

    return result
192214
# ---------------------------------------------------------------------------- #
#                                  Run Worker                                  #
# ---------------------------------------------------------------------------- #
async def run(config: Dict[str, Any]) -> None:
    """Run the worker: poll the hook for jobs and dispatch each as a task.

    Loops forever; yields control to the event loop (sleep(0)) between polls
    so the spawned _process_job tasks can make progress.

    Args:
        config: A dictionary containing the following keys:
            handler: A function that takes a job and returns a result.
            max_concurrency (optional): concurrency passed to get_jobs, default 1.
            max_jobs (optional): max jobs fetched per poll, default 1.
    """
    max_concurrency = config.get('max_concurrency', 1)
    max_jobs = config.get('max_jobs', 1)

    serverless_hook = Hook()

    while True:
        jobs = serverless_hook.get_jobs(max_concurrency, max_jobs)

        # BUG FIX: the None check must precede len() — the original
        # `len(jobs) == 0 or jobs is None` raised TypeError whenever
        # get_jobs returned None, before the None guard could run.
        if jobs is None or len(jobs) == 0:
            await asyncio.sleep(0)
            continue

        for job in jobs:
            asyncio.create_task(_process_job(config, job, serverless_hook), name=job['id'])
            await asyncio.sleep(0)

        await asyncio.sleep(0)
@@ -220,6 +244,7 @@ async def run(config: Dict[str, Any]) -> None:
220244def main (config : Dict [str , Any ]) -> None :
221245 """Run the worker in an asyncio event loop."""
222246 if config .get ('handler' ) is None :
247+ log .error ("SLS Core | config must contain a handler function" )
223248 raise ValueError ("config must contain a handler function" )
224249
225250 try :
0 commit comments