Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 13 additions & 9 deletions runpod/serverless/modules/rp_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from .rp_handler import is_generator
from .rp_job import run_job, run_job_generator
from .rp_ping import Heartbeat
from .worker_state import Jobs
from .worker_state import Job, JobsProgress

RUNPOD_ENDPOINT_ID = os.environ.get("RUNPOD_ENDPOINT_ID", None)

Expand Down Expand Up @@ -96,7 +96,7 @@


# ------------------------------ Initializations ----------------------------- #
job_list = Jobs()
job_list = JobsProgress()
heartbeat = Heartbeat()


Expand Down Expand Up @@ -286,12 +286,12 @@ async def _realtime(self, job: Job):
Performs model inference on the input data using the provided handler.
If handler is not provided, returns an error message.
"""
job_list.add_job(job.id)
job_list.add(job.id)

# Process the job using the provided handler, passing in the job input.
job_results = await run_job(self.config["handler"], job.__dict__)

job_list.remove_job(job.id)
job_list.remove(job.id)

# Return the results of the job processing.
return jsonable_encoder(job_results)
Expand All @@ -304,7 +304,11 @@ async def _realtime(self, job: Job):
async def _sim_run(self, job_request: DefaultRequest) -> JobOutput:
"""Development endpoint to simulate run behavior."""
assigned_job_id = f"test-{uuid.uuid4()}"
job_list.add_job(assigned_job_id, job_request.input, job_request.webhook)
job_list.add({
"id": assigned_job_id,
"input": job_request.input,
"webhook": job_request.webhook
})
return jsonable_encoder({"id": assigned_job_id, "status": "IN_PROGRESS"})

# ---------------------------------- runsync --------------------------------- #
Expand Down Expand Up @@ -341,7 +345,7 @@ async def _sim_runsync(self, job_request: DefaultRequest) -> JobOutput:
# ---------------------------------- stream ---------------------------------- #
async def _sim_stream(self, job_id: str) -> StreamOutput:
"""Development endpoint to simulate stream behavior."""
stashed_job = job_list.get_job(job_id)
stashed_job = job_list.get(job_id)
if stashed_job is None:
return jsonable_encoder(
{"id": job_id, "status": "FAILED", "error": "Job ID not found"}
Expand All @@ -363,7 +367,7 @@ async def _sim_stream(self, job_id: str) -> StreamOutput:
}
)

job_list.remove_job(job.id)
job_list.remove(job.id)

if stashed_job.webhook:
thread = threading.Thread(
Expand All @@ -380,7 +384,7 @@ async def _sim_stream(self, job_id: str) -> StreamOutput:
# ---------------------------------- status ---------------------------------- #
async def _sim_status(self, job_id: str) -> JobOutput:
"""Development endpoint to simulate status behavior."""
stashed_job = job_list.get_job(job_id)
stashed_job = job_list.get(job_id)
if stashed_job is None:
return jsonable_encoder(
{"id": job_id, "status": "FAILED", "error": "Job ID not found"}
Expand All @@ -396,7 +400,7 @@ async def _sim_status(self, job_id: str) -> JobOutput:
else:
job_output = await run_job(self.config["handler"], job.__dict__)

job_list.remove_job(job.id)
job_list.remove(job.id)

if job_output.get("error", None):
return jsonable_encoder(
Expand Down
50 changes: 47 additions & 3 deletions runpod/serverless/modules/rp_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,16 @@
from runpod.serverless.modules.rp_logger import RunPodLogger

from ...version import __version__ as runpod_version
from ..utils import rp_debugger
from .rp_handler import is_generator
from .rp_http import send_result, stream_result
from .rp_tips import check_return_size
from .worker_state import WORKER_ID, JobsQueue
from .worker_state import WORKER_ID, REF_COUNT_ZERO, JobsProgress

JOB_GET_URL = str(os.environ.get("RUNPOD_WEBHOOK_GET_JOB")).replace("$ID", WORKER_ID)

log = RunPodLogger()
job_list = JobsQueue()
job_progress = JobsProgress()


def _job_get_url(batch_size: int = 1):
Expand All @@ -32,7 +35,7 @@ def _job_get_url(batch_size: int = 1):
Returns:
str: The prepared URL for the 'get' request to the serverless API.
"""
job_in_progress = "1" if job_list.get_job_count() else "0"
job_in_progress = "1" if job_progress.get_job_count() else "0"

if batch_size > 1:
job_take_url = JOB_GET_URL.replace("/job-take/", "/job-take-batch/")
Expand Down Expand Up @@ -96,6 +99,47 @@ async def get_job(
return []


async def handle_job(session: ClientSession, config: Dict[str, Any], job) -> None:
    """
    Run a single job through the configured handler and report its result.

    Generator handlers are consumed via ``run_job_generator`` and every
    partial output is forwarded with ``stream_result``; any other handler is
    awaited once via ``run_job``. The final result is then passed to
    ``send_result``.

    Args:
        session: HTTP client session handed to ``stream_result`` and
            ``send_result``.
        config: Worker configuration. ``handler`` is required. The optional
            keys ``return_aggregate_stream``, ``refresh_worker`` and
            ``rp_args`` are read with defaults. ``reference_counter_start``
            is read only when debugger output is attached.
        job: Job payload; ``job["id"]`` is used to tag log messages.

    Returns:
        None. The result is delivered through ``send_result``.
    """
    handler = config["handler"]
    is_stream = is_generator(handler)

    if is_stream:
        generator_output = run_job_generator(handler, job)
        log.debug("Handler is a generator, streaming results.", job["id"])

        job_result = {"output": []}
        async for stream_output in generator_output:
            log.debug(f"Stream output: {stream_output}", job["id"])
            if "error" in stream_output:
                # An error chunk replaces the aggregate and ends the stream.
                job_result = stream_output
                break
            if config.get("return_aggregate_stream", False):
                job_result["output"].append(stream_output["output"])

            await stream_result(session, stream_output, job)
    else:
        job_result = await run_job(handler, job)

    # If refresh_worker is set, pod will be reset after job is complete.
    if config.get("refresh_worker", False):
        log.info("refresh_worker flag set, stopping pod after job.", job["id"])
        job_result["stopPod"] = True

    # If rp_debugger is set, debugger output will be returned.
    # "rp_args" may be missing (or None) when the worker is started without
    # CLI arguments; treat that the same as "flag not set".
    debugger_enabled = (config.get("rp_args") or {}).get("rp_debugger", False)

    if not debugger_enabled or not isinstance(job_result, dict):
        log.debug("rp_debugger | Flag not set, skipping debugger output.", job["id"])
        rp_debugger.clear_debugger_output()
    elif not isinstance(job_result.get("output"), dict):
        # Error results carry no "output" and generator results aggregate
        # into a list; neither can hold the "rp_debugger" key. Skip instead
        # of raising so the result below is still sent.
        log.debug(
            "rp_debugger | Output is not a dict, skipping debugger output.",
            job["id"],
        )
        rp_debugger.clear_debugger_output()
    else:
        job_result["output"]["rp_debugger"] = rp_debugger.get_debugger_output()
        log.debug("rp_debugger | Flag set, returning debugger output.", job["id"])

        # Calculate ready delay for the debugger output.
        ready_delay = (config["reference_counter_start"] - REF_COUNT_ZERO) * 1000
        job_result["output"]["rp_debugger"]["ready_delay_ms"] = ready_delay

    # Send the job result back to JOB_DONE_URL
    await send_result(session, job_result, job, is_stream=is_stream)


async def run_job(handler: Callable, job: Dict[str, Any]) -> Dict[str, Any]:
"""
Run the job using the handler.
Expand Down
18 changes: 9 additions & 9 deletions runpod/serverless/modules/rp_ping.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,26 +12,26 @@

from runpod.http_client import SyncClientSession
from runpod.serverless.modules.rp_logger import RunPodLogger
from runpod.serverless.modules.worker_state import WORKER_ID, JobsQueue
from runpod.serverless.modules.worker_state import WORKER_ID, JobsProgress
from runpod.version import __version__ as runpod_version

log = RunPodLogger()
jobs = JobsQueue() # Contains the list of jobs that are currently running.
jobs = JobsProgress() # Contains the list of jobs that are currently running.


class Heartbeat:
"""Sends heartbeats to the Runpod server."""

PING_URL = os.environ.get("RUNPOD_WEBHOOK_PING", "PING_NOT_SET")
PING_URL = PING_URL.replace("$RUNPOD_POD_ID", WORKER_ID)
PING_INTERVAL = int(os.environ.get("RUNPOD_PING_INTERVAL", 10000)) // 1000

_thread_started = False

def __init__(self, pool_connections=10, retries=3) -> None:
"""
Initializes the Heartbeat class.
"""
self.PING_URL = os.environ.get("RUNPOD_WEBHOOK_PING", "PING_NOT_SET")
self.PING_URL = self.PING_URL.replace("$RUNPOD_POD_ID", WORKER_ID)
self.PING_INTERVAL = int(os.environ.get("RUNPOD_PING_INTERVAL", 10000)) // 1000

self._session = SyncClientSession()
self._session.headers.update(
{"Authorization": f"{os.environ.get('RUNPOD_AI_API_KEY')}"}
Expand All @@ -56,15 +56,15 @@ def start_ping(self, test=False):
"""
Sends heartbeat pings to the Runpod server.
"""
if os.environ.get("RUNPOD_AI_API_KEY") is None:
if not os.environ.get("RUNPOD_AI_API_KEY"):
log.debug("Not deployed on RunPod serverless, pings will not be sent.")
return

if os.environ.get("RUNPOD_POD_ID") is None:
if not os.environ.get("RUNPOD_POD_ID"):
log.info("Not running on RunPod, pings will not be sent.")
return

if self.PING_URL in ["PING_NOT_SET", None]:
if (not self.PING_URL) or self.PING_URL == "PING_NOT_SET":
log.error("Ping URL not set, cannot start ping.")
return

Expand Down
71 changes: 22 additions & 49 deletions runpod/serverless/modules/rp_scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,13 @@
from typing import Any, Dict

from ...http_client import ClientSession
from ..utils import rp_debugger
from .rp_handler import is_generator
from .rp_http import send_result, stream_result
from .rp_job import get_job, run_job, run_job_generator
from .rp_job import get_job, handle_job
from .rp_logger import RunPodLogger
from .worker_state import JobsQueue, REF_COUNT_ZERO
from .worker_state import JobsQueue, JobsProgress

log = RunPodLogger()
job_list = JobsQueue()
job_progress = JobsProgress()


def _default_concurrency_modifier(current_concurrency: int) -> int:
Expand Down Expand Up @@ -68,15 +66,15 @@ async def get_jobs(self, session: ClientSession):
Adds jobs to the JobsQueue
"""
while self.is_alive():
log.debug(f"Jobs in progress: {job_list.get_job_count()}")
log.debug(f"Jobs in progress: {job_progress.get_job_count()}")

try:
self.current_concurrency = self.concurrency_modifier(
self.current_concurrency
)
log.debug(f"Concurrency set to: {self.current_concurrency}")

jobs_needed = self.current_concurrency - job_list.get_job_count()
jobs_needed = self.current_concurrency - job_progress.get_job_count()
if not jobs_needed: # zero or less
log.debug("Queue is full. Retrying soon.")
continue
Expand Down Expand Up @@ -113,7 +111,7 @@ async def run_jobs(self, session: ClientSession, config: Dict[str, Any]):
job = await job_list.get_job()

# Create a new task for each job and add it to the task list
task = asyncio.create_task(self.process_job(session, config, job))
task = asyncio.create_task(self.handle_job(session, config, job))
tasks.append(task)

# Wait for any job to finish
Expand All @@ -133,51 +131,26 @@ async def run_jobs(self, session: ClientSession, config: Dict[str, Any]):
# Ensure all remaining tasks finish before stopping
await asyncio.gather(*tasks)

async def process_job(self, session: ClientSession, config: Dict[str, Any], job):
async def handle_job(self, session: ClientSession, config: Dict[str, Any], job):
"""
Process an individual job. This function is run concurrently for multiple jobs.
"""
log.debug(f"Processing job: {job}")
job_progress.add(job)

if is_generator(config["handler"]):
is_stream = True
generator_output = run_job_generator(config["handler"], job)
log.debug("Handler is a generator, streaming results.", job["id"])

job_result = {"output": []}
async for stream_output in generator_output:
log.debug(f"Stream output: {stream_output}", job["id"])
if "error" in stream_output:
job_result = stream_output
break
if config.get("return_aggregate_stream", False):
job_result["output"].append(stream_output["output"])

await stream_result(session, stream_output, job)
else:
is_stream = False
job_result = await run_job(config["handler"], job)

# If refresh_worker is set, pod will be reset after job is complete.
if config.get("refresh_worker", False):
log.info("refresh_worker flag set, stopping pod after job.", job["id"])
job_result["stopPod"] = True
self.kill_worker()

# If rp_debugger is set, debugger output will be returned.
if config["rp_args"].get("rp_debugger", False) and isinstance(job_result, dict):
job_result["output"]["rp_debugger"] = rp_debugger.get_debugger_output()
log.debug("rp_debugger | Flag set, returning debugger output.", job["id"])

# Calculate ready delay for the debugger output.
ready_delay = (config["reference_counter_start"] - REF_COUNT_ZERO) * 1000
job_result["output"]["rp_debugger"]["ready_delay_ms"] = ready_delay
else:
log.debug("rp_debugger | Flag not set, skipping debugger output.", job["id"])
rp_debugger.clear_debugger_output()
try:
await handle_job(session, config, job)

if config.get("refresh_worker", False):
self.kill_worker()

except Exception as err:
log.error(f"Error handling job: {err}", job["id"])
raise err

# Send the job result back to JOB_DONE_URL
await send_result(session, job_result, job, is_stream=is_stream)
finally:
# Inform JobsQueue of a task completion
job_list.task_done()

# Inform JobsQueue of a task completion
job_list.task_done()
# Job is no longer in progress
job_progress.remove(job["id"])
Loading