diff --git a/polyapi/api.py b/polyapi/api.py index fc243ba..5b0ffde 100644 --- a/polyapi/api.py +++ b/polyapi/api.py @@ -35,6 +35,25 @@ def {function_name}( return {api_response_type}(resp.json()) # type: ignore +async def {function_name}_async( +{args} +) -> {api_response_type}: + \"""{function_description} + + Function ID: {function_id} + \""" + if get_direct_execute_config(): + resp = await direct_execute_async("{function_type}", "{function_id}", {data}) + return {api_response_type}({{ + "status": resp.status_code, + "headers": dict(resp.headers), + "data": resp.json() + }}) # type: ignore + else: + resp = await execute_async("{function_type}", "{function_id}", {data}) + return {api_response_type}(resp.json()) # type: ignore + + """ diff --git a/polyapi/auth.py b/polyapi/auth.py index 3d6c325..66a07e0 100644 --- a/polyapi/auth.py +++ b/polyapi/auth.py @@ -117,6 +117,16 @@ def introspectToken(token: str) -> AuthFunctionResponse: url = "/auth-providers/{function_id}/introspect" resp = execute_post(url, {{"token": token}}) return resp.json() + + +async def introspectToken_async(token: str) -> AuthFunctionResponse: + \"""{description} + + Function ID: {function_id} + \""" + url = "/auth-providers/{function_id}/introspect" + resp = await execute_post_async(url, {{"token": token}}) + return resp.json() """ REFRESH_TOKEN_TEMPLATE = """ @@ -128,6 +138,16 @@ def refreshToken(token: str) -> AuthFunctionResponse: url = "/auth-providers/{function_id}/refresh" resp = execute_post(url, {{"token": token}}) return resp.json() + + +async def refreshToken_async(token: str) -> AuthFunctionResponse: + \"""{description} + + Function ID: {function_id} + \""" + url = "/auth-providers/{function_id}/refresh" + resp = await execute_post_async(url, {{"token": token}}) + return resp.json() """ REVOKE_TOKEN_TEMPLATE = """ @@ -142,6 +162,19 @@ def revokeToken(token: str) -> Optional[AuthFunctionResponse]: return resp.json() except: return None + + +async def revokeToken_async(token: str) -> Optional[AuthFunctionResponse]: + \"""{description} + + Function ID: {function_id} + \""" + url = "/auth-providers/{function_id}/revoke" + resp = await execute_post_async(url, {{"token": token}}) + try: + return resp.json() + except: + return None """ diff --git a/polyapi/execute.py b/polyapi/execute.py index 5d75048..d0bf9e3 100644 --- a/polyapi/execute.py +++ b/polyapi/execute.py @@ -1,109 +1,224 @@ -from typing import Dict, Optional -import requests +import httpx import os import logging -from requests import Response from polyapi.config import get_api_key_and_url, get_mtls_config from polyapi.exceptions import PolyApiException +from polyapi import http_client logger = logging.getLogger("poly") -def direct_execute(function_type, function_id, data) -> Response: - """ execute a specific function id/type - """ - api_key, api_url = get_api_key_and_url() - headers = {"Authorization": f"Bearer {api_key}"} - url = f"{api_url}/functions/{function_type}/{function_id}/direct-execute" - - endpoint_info = requests.post(url, json=data, headers=headers) - if endpoint_info.status_code < 200 or endpoint_info.status_code >= 300: - error_content = endpoint_info.content.decode("utf-8", errors="ignore") +def _check_response_error(resp, function_type, function_id, data): + if resp.status_code < 200 or resp.status_code >= 300: + error_content = resp.content.decode("utf-8", errors="ignore") if function_type == 'api' and os.getenv("LOGS_ENABLED"): - raise PolyApiException(f"Error executing api function with id: {function_id}. Status code: {endpoint_info.status_code}. Request data: {data}, Response: {error_content}") + logger.error(f"Error executing api function with id: {function_id}. Status code: {resp.status_code}. Request data: {data}, Response: {error_content}") elif function_type != 'api': - raise PolyApiException(f"{endpoint_info.status_code}: {error_content}") - - endpoint_info_data = endpoint_info.json() + raise PolyApiException(f"{resp.status_code}: {error_content}") + + +def _check_endpoint_error(resp, function_type, function_id, data): + if resp.status_code < 200 or resp.status_code >= 300: + error_content = resp.content.decode("utf-8", errors="ignore") + if function_type == 'api' and os.getenv("LOGS_ENABLED"): + raise PolyApiException(f"Error executing api function with id: {function_id}. Status code: {resp.status_code}. Request data: {data}, Response: {error_content}") + elif function_type != 'api': + raise PolyApiException(f"{resp.status_code}: {error_content}") + + +def _build_direct_execute_params(endpoint_info_data): request_params = endpoint_info_data.copy() request_params.pop("url", None) - if "maxRedirects" in request_params: - request_params["allow_redirects"] = request_params.pop("maxRedirects") > 0 - + request_params["follow_redirects"] = request_params.pop("maxRedirects") > 0 + return request_params + + +def _sync_direct_execute(function_type, function_id, data) -> httpx.Response: + api_key, api_url = get_api_key_and_url() + headers = {"Authorization": f"Bearer {api_key}"} + url = f"{api_url}/functions/{function_type}/{function_id}/direct-execute" + + endpoint_info = http_client.post(url, json=data, headers=headers) + _check_endpoint_error(endpoint_info, function_type, function_id, data) + + endpoint_info_data = endpoint_info.json() + request_params = _build_direct_execute_params(endpoint_info_data) + has_mtls, cert_path, key_path, ca_path = get_mtls_config() - + + # Direct-execute hits URL that may need custom TLS + # settings (mTLS certs or disabled verification). httpx Client.request() + # doesn't accept per-request transport kwargs, so use one-off calls. if has_mtls: - resp = requests.request( + resp = httpx.request( url=endpoint_info_data["url"], cert=(cert_path, key_path), verify=ca_path, + timeout=None, **request_params ) else: - resp = requests.request( + resp = httpx.request( url=endpoint_info_data["url"], verify=False, + timeout=None, **request_params ) - if (resp.status_code < 200 or resp.status_code >= 300): - error_content = resp.content.decode("utf-8", errors="ignore") - if function_type == 'api' and os.getenv("LOGS_ENABLED"): - logger.error(f"Error executing api function with id: {function_id}. Status code: {resp.status_code}. Request data: {data}, Response: {error_content}") - elif function_type != 'api': - raise PolyApiException(f"{resp.status_code}: {error_content}") - + _check_response_error(resp, function_type, function_id, data) + return resp + + +async def _async_direct_execute(function_type, function_id, data) -> httpx.Response: + api_key, api_url = get_api_key_and_url() + headers = {"Authorization": f"Bearer {api_key}"} + url = f"{api_url}/functions/{function_type}/{function_id}/direct-execute" + + endpoint_info = await http_client.async_post(url, json=data, headers=headers) + _check_endpoint_error(endpoint_info, function_type, function_id, data) + + endpoint_info_data = endpoint_info.json() + request_params = _build_direct_execute_params(endpoint_info_data) + + has_mtls, cert_path, key_path, ca_path = get_mtls_config() + + # One-off async client for custom TLS settings on external URLs. + if has_mtls: + async with httpx.AsyncClient( + cert=(cert_path, key_path), verify=ca_path, timeout=None + ) as client: + resp = await client.request( + url=endpoint_info_data["url"], **request_params + ) + else: + async with httpx.AsyncClient(verify=False, timeout=None) as client: + resp = await client.request( + url=endpoint_info_data["url"], **request_params + ) + + _check_response_error(resp, function_type, function_id, data) return resp -def execute(function_type, function_id, data) -> Response: - """ execute a specific function id/type + +def direct_execute(function_type, function_id, data) -> httpx.Response: + """ execute a specific function id/type (sync) """ + return _sync_direct_execute(function_type, function_id, data) + + +async def direct_execute_async(function_type, function_id, data) -> httpx.Response: + """ execute a specific function id/type (async) + """ + return await _async_direct_execute(function_type, function_id, data) + + +def _sync_execute(function_type, function_id, data) -> httpx.Response: api_key, api_url = get_api_key_and_url() headers = {"Authorization": f"Bearer {api_key}"} + url = f"{api_url}/functions/{function_type}/{function_id}/execute" + resp = http_client.post(url, json=data, headers=headers) + _check_response_error(resp, function_type, function_id, data) + return resp + + +async def _async_execute(function_type, function_id, data) -> httpx.Response: + api_key, api_url = get_api_key_and_url() + headers = {"Authorization": f"Bearer {api_key}"} url = f"{api_url}/functions/{function_type}/{function_id}/execute" - - # Make the request - resp = requests.post( - url, - json=data, - headers=headers, - ) - - if (resp.status_code < 200 or resp.status_code >= 300) and os.getenv("LOGS_ENABLED"): - error_content = resp.content.decode("utf-8", errors="ignore") - if function_type == 'api' and os.getenv("LOGS_ENABLED"): - logger.error(f"Error executing api function with id: {function_id}. Status code: {resp.status_code}. Request data: {data}, Response: {error_content}") - elif function_type != 'api': - raise PolyApiException(f"{resp.status_code}: {error_content}") + resp = await http_client.async_post(url, json=data, headers=headers) + _check_response_error(resp, function_type, function_id, data) return resp -def execute_post(path, data): +def execute(function_type, function_id, data) -> httpx.Response: + """ execute a specific function id/type (sync) + """ + return _sync_execute(function_type, function_id, data) + + +async def execute_async(function_type, function_id, data) -> httpx.Response: + """ execute a specific function id/type (async) + """ + return await _async_execute(function_type, function_id, data) + + +def _sync_execute_post(path, data): + api_key, api_url = get_api_key_and_url() + headers = {"Authorization": f"Bearer {api_key}"} + return http_client.post(api_url + path, json=data, headers=headers) + + +async def _async_execute_post(path, data): api_key, api_url = get_api_key_and_url() headers = {"Authorization": f"Bearer {api_key}"} - resp = requests.post(api_url + path, json=data, headers=headers) + return await http_client.async_post(api_url + path, json=data, headers=headers) + + +def execute_post(path, data): + return _sync_execute_post(path, data) + + +async def execute_post_async(path, data): + return await _async_execute_post(path, data) + + +def _sync_variable_get(variable_id: str) -> httpx.Response: + api_key, base_url = get_api_key_and_url() + headers = {"Authorization": f"Bearer {api_key}"} + url = f"{base_url}/variables/{variable_id}/value" + resp = http_client.get(url, headers=headers) + if resp.status_code != 200 and resp.status_code != 201: + error_content = resp.content.decode("utf-8", errors="ignore") + raise PolyApiException(f"{resp.status_code}: {error_content}") return resp -def variable_get(variable_id: str) -> Response: +async def _async_variable_get(variable_id: str) -> httpx.Response: api_key, base_url = get_api_key_and_url() headers = {"Authorization": f"Bearer {api_key}"} url = f"{base_url}/variables/{variable_id}/value" - resp = requests.get(url, headers=headers) + resp = await http_client.async_get(url, headers=headers) if resp.status_code != 200 and resp.status_code != 201: error_content = resp.content.decode("utf-8", errors="ignore") raise PolyApiException(f"{resp.status_code}: {error_content}") return resp -def variable_update(variable_id: str, value) -> Response: +def variable_get(variable_id: str) -> httpx.Response: + return _sync_variable_get(variable_id) + + +async def variable_get_async(variable_id: str) -> httpx.Response: + return await _async_variable_get(variable_id) + + +def _sync_variable_update(variable_id: str, value) -> httpx.Response: api_key, base_url = get_api_key_and_url() headers = {"Authorization": f"Bearer {api_key}"} url = f"{base_url}/variables/{variable_id}" - resp = requests.patch(url, data={"value": value}, headers=headers) + resp = http_client.patch(url, data={"value": value}, headers=headers) if resp.status_code != 200 and resp.status_code != 201: error_content = resp.content.decode("utf-8", errors="ignore") raise PolyApiException(f"{resp.status_code}: {error_content}") - return resp \ No newline at end of file + return resp + + +async def _async_variable_update(variable_id: str, value) -> httpx.Response: + api_key, base_url = get_api_key_and_url() + headers = {"Authorization": f"Bearer {api_key}"} + url = f"{base_url}/variables/{variable_id}" + resp = await http_client.async_patch(url, data={"value": value}, headers=headers) + if resp.status_code != 200 and resp.status_code != 201: + error_content = resp.content.decode("utf-8", errors="ignore") + raise PolyApiException(f"{resp.status_code}: {error_content}") + return resp + + +def variable_update(variable_id: str, value) -> httpx.Response: + return _sync_variable_update(variable_id, value) + + +async def variable_update_async(variable_id: str, value) -> httpx.Response: + return await _async_variable_update(variable_id, value) diff --git a/polyapi/function_cli.py b/polyapi/function_cli.py index 219bbb0..2ce79a8 100644 --- a/polyapi/function_cli.py +++ b/polyapi/function_cli.py @@ -1,7 +1,6 @@ import sys from typing import Any, List, Optional -import requests - +from polyapi import http_client from polyapi.config import get_api_key_and_url from polyapi.utils import get_auth_headers, print_green, print_red, print_yellow from polyapi.parser import parse_function_code, get_jsonschema_type @@ -87,7 +86,7 @@ def function_add_or_update( sys.exit(1) headers = get_auth_headers(api_key) - resp = requests.post(url, headers=headers, json=data) + resp = http_client.post(url, headers=headers, json=data) if resp.status_code in [200, 201]: print_green("DEPLOYED") function_id = resp.json()["id"] @@ -126,5 +125,5 @@ def spec_delete(function_type: str, function_id: str): print(f"Unknown function type: {function_type}") sys.exit(1) headers = get_auth_headers(api_key) - resp = requests.delete(url, headers=headers) + resp = http_client.delete(url, headers=headers) return resp \ No newline at end of file diff --git a/polyapi/generate.py b/polyapi/generate.py index 00366c9..883f0ca 100644 --- a/polyapi/generate.py +++ b/polyapi/generate.py @@ -1,5 +1,4 @@ import json -import requests import os import uuid import shutil @@ -20,6 +19,7 @@ from .utils import add_import_to_init, get_auth_headers, init_the_init, print_green, to_func_namespace from .variables import generate_variables from .poly_tables import generate_tables +from . import http_client from .config import get_api_key_and_url, get_direct_execute_config, get_cached_generate_args SUPPORTED_FUNCTION_TYPES = { @@ -62,7 +62,7 @@ def get_specs(contexts: Optional[List[str]] = None, names: Optional[List[str]] = if get_direct_execute_config(): params["apiFunctionDirectExecute"] = "true" - resp = requests.get(url, headers=headers, params=params) + resp = http_client.get(url, headers=headers, params=params) if resp.status_code == 200: return resp.json() else: diff --git a/polyapi/http_client.py b/polyapi/http_client.py new file mode 100644 index 0000000..c97d06f --- /dev/null +++ b/polyapi/http_client.py @@ -0,0 +1,73 @@ +import asyncio +import httpx + +_sync_client: httpx.Client | None = None +_async_client: httpx.AsyncClient | None = None + + +def _get_sync_client() -> httpx.Client: + global _sync_client + if _sync_client is None: + _sync_client = httpx.Client(timeout=None) + return _sync_client + + +def _get_async_client() -> httpx.AsyncClient: + global _async_client + if _async_client is None: + _async_client = httpx.AsyncClient(timeout=None) + return _async_client + + +def post(url, **kwargs) -> httpx.Response: + return _get_sync_client().post(url, **kwargs) + + +async def async_post(url, **kwargs) -> httpx.Response: + return await _get_async_client().post(url, **kwargs) + + +def get(url, **kwargs) -> httpx.Response: + return _get_sync_client().get(url, **kwargs) + + +async def async_get(url, **kwargs) -> httpx.Response: + return await _get_async_client().get(url, **kwargs) + + +def patch(url, **kwargs) -> httpx.Response: + return _get_sync_client().patch(url, **kwargs) + + +async def async_patch(url, **kwargs) -> httpx.Response: + return await _get_async_client().patch(url, **kwargs) + + +def delete(url, **kwargs) -> httpx.Response: + return _get_sync_client().delete(url, **kwargs) + + +async def async_delete(url, **kwargs) -> httpx.Response: + return await _get_async_client().delete(url, **kwargs) + + +def request(method, url, **kwargs) -> httpx.Response: + return _get_sync_client().request(method, url, **kwargs) + + +async def async_request(method, url, **kwargs) -> httpx.Response: + return await _get_async_client().request(method, url, **kwargs) + + +def close(): + global _sync_client + if _sync_client is not None: + _sync_client.close() + _sync_client = None + +async def close_async(): + global _sync_client, _async_client + close() + if _async_client is not None: + await _async_client.aclose() + _async_client = None \ No newline at end of file diff --git a/polyapi/poly_tables.py b/polyapi/poly_tables.py index f98de9a..4a391fa 100644 --- a/polyapi/poly_tables.py +++ b/polyapi/poly_tables.py @@ -1,5 +1,5 @@ import os -import requests +from polyapi import http_client from typing_extensions import NotRequired, TypedDict from typing import ( List, @@ -88,7 +88,7 @@ def execute_query(table_id, method, query): headers = {"x-poly-execution-id": polyCustom.get("executionId")} if auth_key: headers["Authorization"] = f"Bearer {auth_key}" - response = requests.post(url, json=query, headers=headers) + response = http_client.post(url, json=query, headers=headers) response.raise_for_status() return response.json() except Exception as e: diff --git a/polyapi/prepare.py b/polyapi/prepare.py index b1580e2..b13b337 100644 --- a/polyapi/prepare.py +++ b/polyapi/prepare.py @@ -2,8 +2,7 @@ import sys import subprocess from typing import List, Tuple, Literal -import requests - +from polyapi import http_client from polyapi.utils import get_auth_headers from polyapi.config import get_api_key_and_url from polyapi.parser import parse_function_code @@ -32,7 +31,7 @@ def get_server_function_description(description: str, arguments, code: str) -> s api_key, api_url = get_api_key_and_url() headers = get_auth_headers(api_key) data = {"description": description, "arguments": arguments, "code": code} - response = requests.post(f"{api_url}/functions/server/description-generation", headers=headers, json=data) + response = http_client.post(f"{api_url}/functions/server/description-generation", headers=headers, json=data) return response.json() def get_client_function_description(description: str, arguments, code: str) -> str: @@ -40,7 +39,7 @@ def get_client_function_description(description: str, arguments, code: str) -> s headers = get_auth_headers(api_key) # Simulated API call to generate client function descriptions data = {"description": description, "arguments": arguments, "code": code} - response = requests.post(f"{api_url}/functions/client/description-generation", headers=headers, json=data) + response = http_client.post(f"{api_url}/functions/client/description-generation", headers=headers, json=data) return response.json() def fill_in_missing_function_details(deployable: DeployableRecord, code: str) -> DeployableRecord: diff --git a/polyapi/rendered_spec.py b/polyapi/rendered_spec.py index 2206417..e746a56 100644 --- a/polyapi/rendered_spec.py +++ b/polyapi/rendered_spec.py @@ -1,7 +1,7 @@ import os from typing import Optional -import requests +from polyapi import http_client from polyapi.config import get_api_key_and_url from polyapi.generate import read_cached_specs, render_spec from polyapi.typedefs import SpecificationDto @@ -31,7 +31,7 @@ def update_rendered_spec(spec: SpecificationDto): url = f"{base_url}/functions/rendered-specs" headers = {"Authorization": f"Bearer {api_key}"} - resp = requests.post(url, json=data, headers=headers) + resp = http_client.post(url, json=data, headers=headers) assert resp.status_code == 201, (resp.text, resp.status_code) @@ -40,7 +40,7 @@ def _get_spec(spec_id: str, no_types: bool = False) -> Optional[SpecificationDto url = f"{base_url}/specs" headers = {"Authorization": f"Bearer {api_key}"} params = {"noTypes": str(no_types).lower()} - resp = requests.get(url, headers=headers, params=params) + resp = http_client.get(url, headers=headers, params=params) if resp.status_code == 200: specs = resp.json() for spec in specs: diff --git a/polyapi/server.py b/polyapi/server.py index 53b173e..9ee8ae0 100644 --- a/polyapi/server.py +++ b/polyapi/server.py @@ -24,6 +24,20 @@ def {function_name}( return resp.text # type: ignore # fallback for debugging +async def {function_name}_async( +{args} +) -> {return_type_name}: + \"""{function_description} + + Function ID: {function_id} + \""" + resp = await execute_async("{function_type}", "{function_id}", {data}) + try: + return {return_action} + except: + return resp.text # type: ignore # fallback for debugging + + """ diff --git a/polyapi/sync.py b/polyapi/sync.py index 850538b..2b82c45 100644 --- a/polyapi/sync.py +++ b/polyapi/sync.py @@ -2,8 +2,7 @@ from datetime import datetime from typing import List, Dict from typing_extensions import cast # type: ignore -import requests - +from polyapi import http_client from polyapi.utils import get_auth_headers from polyapi.config import get_api_key_and_url from polyapi.parser import get_jsonschema_type @@ -35,10 +34,10 @@ def remove_deployable_function(deployable: SyncDeployment) -> bool: raise Exception("Missing api key!") headers = get_auth_headers(api_key) url = f'{deployable["instance"]}/functions/{deployable["type"].replace("-function", "")}/{deployable["id"]}' - response = requests.get(url, headers=headers) + response = http_client.get(url, headers=headers) if response.status_code != 200: return False - requests.delete(url, headers=headers) + http_client.delete(url, headers=headers) return True def remove_deployable(deployable: SyncDeployment) -> bool: @@ -65,7 +64,7 @@ def sync_function_and_get_id(deployable: SyncDeployment, code: str) -> str: "returnTypeSchema": deployable["types"]["returns"]["typeSchema"], "arguments": [{**p, "key": p["name"], "type": get_jsonschema_type(p["type"]) } for p in deployable["types"]["params"]], } - response = requests.post(url, headers=headers, json=payload) + response = http_client.post(url, headers=headers, json=payload) response.raise_for_status() return response.json()['id'] diff --git a/polyapi/utils.py b/polyapi/utils.py index fb57199..1a6d168 100644 --- a/polyapi/utils.py +++ b/polyapi/utils.py @@ -16,7 +16,7 @@ # this string should be in every __init__ file. # it contains all the imports needed for the function or variable code to run -CODE_IMPORTS = "from typing import List, Dict, Any, Optional, Callable\nfrom typing_extensions import TypedDict, NotRequired\nimport logging\nimport requests\nimport socketio # type: ignore\nfrom polyapi.config import get_api_key_and_url, get_direct_execute_config\nfrom polyapi.execute import execute, execute_post, variable_get, variable_update, direct_execute\n\n" +CODE_IMPORTS = "from typing import List, Dict, Any, Optional, Callable\nfrom typing_extensions import TypedDict, NotRequired\nimport logging\nimport requests\nimport socketio # type: ignore\nfrom polyapi.config import get_api_key_and_url, get_direct_execute_config\nfrom polyapi.execute import execute, execute_async, execute_post, execute_post_async, variable_get, variable_get_async, variable_update, variable_update_async, direct_execute, direct_execute_async\n\n" def init_the_init(full_path: str, code_imports: Optional[str] = None) -> None: @@ -73,7 +73,6 @@ def print_red(s: str): def add_type_import_path(function_name: str, arg: str) -> str: """if not basic type, coerce to camelCase and add the import path""" - # outdated og comment - for now, just treat Callables as basic types # from now, we start qualifying non-basic types :)) # e.g. Callable[[EmailAddress, Dict, Dict, Dict], None] # becomes Callable[[Set_profile_email.EmailAddress, Dict, Dict, Dict], None] diff --git a/polyapi/variables.py b/polyapi/variables.py index 1fb915d..b8220b7 100644 --- a/polyapi/variables.py +++ b/polyapi/variables.py @@ -15,6 +15,11 @@ def get() -> {variable_type}: resp = variable_get("{variable_id}") return resp.text + + @staticmethod + async def get_async() -> {variable_type}: + resp = await variable_get_async("{variable_id}") + return resp.text """ @@ -30,6 +35,11 @@ def update(value: {variable_type}): resp = variable_update("{variable_id}", value) return resp.json() + @staticmethod + async def update_async(value: {variable_type}): + resp = await variable_update_async("{variable_id}", value) + return resp.json() + @classmethod async def onUpdate(cls, callback): api_key, base_url = get_api_key_and_url() diff --git a/pyproject.toml b/pyproject.toml index dc5916e..d651f61 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "polyapi-python" -version = "0.3.13" +version = "0.3.14.dev1" description = "The Python Client for PolyAPI, the IPaaS by Developers for Developers" authors = [{ name = "Dan Fellin", email = "dan@polyapi.io" }] dependencies = [ @@ -16,6 +16,7 @@ dependencies = [ "colorama==0.4.4", "python-socketio[asyncio_client]==5.11.1", "truststore>=0.8.0", + "httpx>=0.28.1" ] readme = "README.md" license = { file = "LICENSE" } diff --git a/requirements.txt b/requirements.txt index b5eb3c4..1a07d23 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,5 @@ pydantic>=2.8.0 stdlib_list>=0.10.0 colorama==0.4.4 python-socketio[asyncio_client]==5.11.1 -truststore>=0.8.0 \ No newline at end of file +truststore>=0.8.0 +httpx>=0.28.1 \ No newline at end of file diff --git a/tests/test_async_proof.py b/tests/test_async_proof.py new file mode 100644 index 0000000..88efd74 --- /dev/null +++ b/tests/test_async_proof.py @@ -0,0 +1,412 @@ +"""Tests proving the sync/async split works correctly. + +These tests mock HTTP calls so no live server is needed. They verify that: +1. is_async() correctly detects sync vs async context +2. http_client uses sync Client in sync context, AsyncClient in async context +3. Sync functions (execute, direct_execute, etc.) always return httpx.Response +4. Async functions (execute_async, direct_execute_async, etc.) return coroutines +5. Parallel async execution with asyncio.gather works +""" + +import asyncio +import inspect +from unittest.mock import patch, MagicMock, AsyncMock + +import httpx +import pytest + +from polyapi import http_client +from polyapi.execute import ( + execute, + execute_async, + direct_execute, + direct_execute_async, + execute_post, + execute_post_async, + variable_get, + variable_get_async, + variable_update, + variable_update_async, + _build_direct_execute_params, + _check_endpoint_error, + _check_response_error +) +from polyapi.exceptions import PolyApiException + + +# Helpers + +def _fake_response(status_code=200, json_data=None, text="ok"): + """Build a fake httpx.Response.""" + resp = MagicMock(spec=httpx.Response) + resp.status_code = status_code + resp.text = text + resp.content = text.encode() + resp.json.return_value = {} if json_data is None else json_data + return resp + + +# 1. http_client sync / async client pairing + +class TestHttpClientPairing: + """Verify that the sync helpers call httpx.Client and the async helpers + call httpx.AsyncClient.""" + + def setup_method(self): + # Reset singletons so each test starts fresh + http_client._sync_client = None + http_client._async_client = None + + def teardown_method(self): + http_client._sync_client = None + http_client._async_client = None + + @patch.object(httpx.Client, "post", return_value=_fake_response()) + def test_sync_post_uses_sync_client(self, mock_post): + resp = http_client.post("https://example.com", json={}) + mock_post.assert_called_once() + assert resp.status_code == 200 + # The sync client should have been created + assert http_client._sync_client is not None + assert http_client._async_client is None + + @patch.object(httpx.AsyncClient, "post", new_callable=AsyncMock, return_value=_fake_response()) + def test_async_post_uses_async_client(self, mock_post): + async def _run(): + return await http_client.async_post("https://example.com", json={}) + + resp = asyncio.run(_run()) + mock_post.assert_called_once() + assert resp.status_code == 200 + assert http_client._async_client is not None + + @patch.object(httpx.Client, "get", return_value=_fake_response()) + def test_sync_get(self, mock_get): + resp = http_client.get("https://example.com") + mock_get.assert_called_once() + assert resp.status_code == 200 + + @patch.object(httpx.AsyncClient, "get", new_callable=AsyncMock, return_value=_fake_response()) + def test_async_get(self, mock_get): + async def _run(): + return await http_client.async_get("https://example.com") + + resp = asyncio.run(_run()) + mock_get.assert_called_once() + assert resp.status_code == 200 + + @patch.object(httpx.Client, "patch", return_value=_fake_response()) + def test_sync_patch(self, mock_patch_method): + resp = http_client.patch("https://example.com", data={"v": 1}) + mock_patch_method.assert_called_once() + assert resp.status_code == 200 + + @patch.object(httpx.AsyncClient, "patch", new_callable=AsyncMock, return_value=_fake_response()) + def test_async_patch(self, mock_patch_method): + async def _run(): + return await http_client.async_patch("https://example.com", data={"v": 1}) + + resp = asyncio.run(_run()) + mock_patch_method.assert_called_once() + + @patch.object(httpx.Client, "delete", return_value=_fake_response()) + def test_sync_delete(self, mock_delete): + resp = http_client.delete("https://example.com") + mock_delete.assert_called_once() + + @patch.object(httpx.AsyncClient, "delete", new_callable=AsyncMock, return_value=_fake_response()) + def test_async_delete(self, mock_delete): + async def _run(): + return await http_client.async_delete("https://example.com") + + resp = asyncio.run(_run()) + mock_delete.assert_called_once() + + @patch.object(httpx.Client, "request", return_value=_fake_response()) + def test_sync_request(self, mock_request): + resp = http_client.request("POST", "https://example.com") + mock_request.assert_called_once() + + @patch.object(httpx.AsyncClient, "request", new_callable=AsyncMock, return_value=_fake_response()) + def test_async_request(self, mock_request): + async def _run(): + return await http_client.async_request("POST", "https://example.com") + + resp = asyncio.run(_run()) + mock_request.assert_called_once() + + +# 3. execute() / execute_async() + +_CONFIG_PATCH = patch( + "polyapi.execute.get_api_key_and_url", + return_value=("fake-key", "https://api.example.com"), +) +_MTLS_PATCH = patch( + "polyapi.execute.get_mtls_config", + return_value=(False, None, None, None), +) + + +class TestExecute: + """execute() always returns httpx.Response (sync). + execute_async() always returns a coroutine that resolves to httpx.Response.""" + + @_CONFIG_PATCH + @patch("polyapi.http_client.post", return_value=_fake_response(200, text='"hello"')) + def test_sync_returns_response(self, mock_post, _mock_config): + result = execute("server", "some-id", {}) + assert isinstance(result, MagicMock) # our fake Response + assert result.status_code == 200 + mock_post.assert_called_once() + + @_CONFIG_PATCH + @patch("polyapi.http_client.async_post", new_callable=AsyncMock, return_value=_fake_response(200, text='"hello"')) + def test_async_returns_coroutine_then_response(self, mock_post, _mock_config): + async def _run(): + coro = execute_async("server", "some-id", {}) + assert inspect.isawaitable(coro) + return await coro + + result = asyncio.run(_run()) + assert result.status_code == 200 + mock_post.assert_called_once() + + @_CONFIG_PATCH + @patch("polyapi.http_client.post", return_value=_fake_response(200, text='"hello"')) + def test_sync_calls_correct_url(self, mock_post, _mock_config): + execute("server", "abc-123", {"arg": 1}) + call_args = mock_post.call_args + assert "/functions/server/abc-123/execute" in call_args[0][0] + assert call_args[1]["json"] == {"arg": 1} + assert "Bearer fake-key" in call_args[1]["headers"]["Authorization"] + + @_CONFIG_PATCH + @patch("polyapi.http_client.post", return_value=_fake_response(200, text='"hello"')) + def test_sync_works_inside_async_context(self, mock_post, _mock_config): + """execute() (sync) should still return httpx.Response even when + called from within an async context — this is the key fix.""" + async def _run(): + result = execute("server", "some-id", {}) + # Should be a Response, NOT a coroutine + assert not inspect.isawaitable(result) + assert result.status_code == 200 + + asyncio.run(_run()) + + +# 4. direct_execute() / direct_execute_async() + +class TestDirectExecute: + + @_CONFIG_PATCH + @_MTLS_PATCH + @patch("polyapi.execute.httpx.request", return_value=_fake_response(200, text='{"result": 1}')) + @patch("polyapi.http_client.post", return_value=_fake_response( + 200, json_data={"url": "https://target.example.com", "method": "GET"}, + )) + def test_sync_returns_response(self, mock_post, mock_request, _mtls, _config): + result = direct_execute("server", "fn-id", {}) + assert result.status_code == 200 + assert "/direct-execute" in mock_post.call_args[0][0] + mock_request.assert_called_once() + + @_CONFIG_PATCH + @_MTLS_PATCH + @patch("polyapi.http_client.async_post", new_callable=AsyncMock, return_value=_fake_response( + 200, json_data={"url": "https://target.example.com", "method": "GET"}, + )) + def test_async_returns_coroutine(self, mock_post, _mtls, _config): + mock_client = AsyncMock() + mock_client.request = AsyncMock(return_value=_fake_response(200)) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + + async def _run(): + with patch("polyapi.execute.httpx.AsyncClient", return_value=mock_client): + coro = direct_execute_async("server", "fn-id", {}) + assert inspect.isawaitable(coro) + return await coro + + result = asyncio.run(_run()) + assert result.status_code == 200 + + +# 5. execute_post() / execute_post_async() + +class TestExecutePost: + + @_CONFIG_PATCH + @patch("polyapi.http_client.post", return_value=_fake_response()) + def test_sync(self, mock_post, _config): + result = execute_post("/some/path", {"data": 1}) + assert result.status_code == 200 + assert "https://api.example.com/some/path" == mock_post.call_args[0][0] + + @_CONFIG_PATCH + @patch("polyapi.http_client.async_post", new_callable=AsyncMock, return_value=_fake_response()) + def test_async(self, mock_post, _config): + async def _run(): + coro = execute_post_async("/some/path", {"data": 1}) + assert inspect.isawaitable(coro) + return await coro + + result = asyncio.run(_run()) + assert result.status_code == 200 + + +# 6. variable_get / variable_get_async + +class TestVariableGet: + + @_CONFIG_PATCH + @patch("polyapi.http_client.get", return_value=_fake_response(200, text="42")) + def test_sync(self, mock_get, _config): + result = variable_get("var-123") + assert result.status_code == 200 + assert "/variables/var-123/value" in mock_get.call_args[0][0] + + @_CONFIG_PATCH + @patch("polyapi.http_client.async_get", new_callable=AsyncMock, return_value=_fake_response(200, text="42")) + def test_async(self, mock_get, _config): + async def _run(): + coro = variable_get_async("var-123") + assert inspect.isawaitable(coro) + return await coro + + result = asyncio.run(_run()) + assert result.status_code == 200 + + +# 7. variable_update / variable_update_async + +class TestVariableUpdate: + + @_CONFIG_PATCH + @patch("polyapi.http_client.patch", return_value=_fake_response(200)) + def test_sync(self, mock_patch_call, _config): + result = variable_update("var-123", "new-value") + assert result.status_code == 200 + + @_CONFIG_PATCH + @patch("polyapi.http_client.async_patch", new_callable=AsyncMock, return_value=_fake_response(200)) + def test_async(self, mock_patch_call, _config): + async def _run(): + coro = variable_update_async("var-123", "new-value") + assert inspect.isawaitable(coro) + return await coro + + result = asyncio.run(_run()) + assert result.status_code == 200 + + +# 8. Parallel async execution with asyncio.gather + +class TestAsyncParallelExecution: + """Prove that multiple async execute_async() calls can be gathered in parallel.""" + + @_CONFIG_PATCH + @patch("polyapi.http_client.async_post", new_callable=AsyncMock) + def test_gather_multiple_executes(self, mock_post, _config): + responses = [_fake_response(200, text=f'"result-{i}"') for i in range(5)] + mock_post.side_effect = responses + + async def _run(): + coros = [execute_async("server", f"fn-{i}", {}) for i in range(5)] + for c in coros: + assert inspect.isawaitable(c) + return await asyncio.gather(*coros) + + results = asyncio.run(_run()) + assert len(results) == 5 + assert mock_post.call_count == 5 + for i, r in enumerate(results): + assert r.text == f'"result-{i}"' + + @_CONFIG_PATCH + @patch("polyapi.http_client.async_post", new_callable=AsyncMock) + def test_gather_is_faster_than_sequential(self, mock_post, _config): + """Measure both sequential and parallel execution in the same run, + then assert parallel < sequential. Avoids flaky CI failures from + hardcoded time thresholds.""" + import time + + async def _slow_post(*args, **kwargs): + await asyncio.sleep(0.1) + return _fake_response(200, text='"done"') + + mock_post.side_effect = _slow_post + + async def _run(): + # Sequential: await one at a time + seq_start = time.monotonic() + for i in range(5): + await execute_async("server", f"fn-{i}", {}) + seq_elapsed = time.monotonic() - seq_start + + # Parallel: gather all at once + par_start = time.monotonic() + results = await asyncio.gather( + *[execute_async("server", f"fn-{i}", {}) for i in range(5)] + ) + par_elapsed = time.monotonic() - par_start + + return results, seq_elapsed, par_elapsed + + results, seq_elapsed, par_elapsed = asyncio.run(_run()) + assert len(results) == 5 + assert par_elapsed < seq_elapsed, ( + f"Parallel ({par_elapsed:.2f}s) should be faster than sequential ({seq_elapsed:.2f}s)" + ) + + +# 9. Helper: _build_direct_execute_params + +class TestBuildDirectExecuteParams: + def test_strips_url(self): + params = _build_direct_execute_params({"url": "https://x.com", "method": "GET"}) + assert "url" not in params + assert params["method"] == "GET" + + def test_converts_max_redirects_positive(self): + params = _build_direct_execute_params({"url": "u", "maxRedirects": 5}) + assert "maxRedirects" not in params + assert params["follow_redirects"] is True + + def test_converts_max_redirects_zero(self): + params = _build_direct_execute_params({"url": "u", "maxRedirects": 0}) + assert params["follow_redirects"] is False + + def test_no_mutation_of_input(self): + original = {"url": "u", "method": "POST", "maxRedirects": 3} + _build_direct_execute_params(original) + assert "url" in original + assert "maxRedirects" in original + + +# 10. _check_endpoint_error vs _check_response_error + +class TestCheckErrorBehaviorDifference: + """_check_endpoint_error raises on api errors; _check_response_error only logs.""" + + + def test_endpoint_error_raises_for_api_with_logs(self): + """_check_endpoint_error raises PolyApiException for api functions.""" + + resp = _fake_response(status_code=500, text="server broke") + with patch.dict("os.environ", {"LOGS_ENABLED": "1"}): + with pytest.raises(PolyApiException, match="500"): + _check_endpoint_error(resp, "api", "fn-1", {}) + + def test_response_error_logs_for_api_with_logs(self): + """_check_response_error only logs (no raise) for api functions.""" + resp = _fake_response(status_code=500, text="server broke") + with patch.dict("os.environ", {"LOGS_ENABLED": "1"}): + # Should NOT raise — just logs + _check_response_error(resp, "api", "fn-1", {}) + + def test_both_raise_for_non_api(self): + """Both functions raise PolyApiException for non-api function types.""" + resp = _fake_response(status_code=500, text="server broke") + with pytest.raises(PolyApiException, match="500"): + _check_endpoint_error(resp, "server", "fn-1", {}) diff --git a/tests/test_rendered_spec.py b/tests/test_rendered_spec.py index f2bf7a5..9c43f07 100644 --- a/tests/test_rendered_spec.py +++ b/tests/test_rendered_spec.py @@ -40,7 +40,7 @@ def test_get_and_update_rendered_spec_fail(self, _get_spec): self.assertEqual(_get_spec.call_count, 1) self.assertFalse(updated) - @patch("polyapi.rendered_spec.requests.post") + @patch("polyapi.http_client.post") @patch("polyapi.rendered_spec._get_spec") def test_get_and_update_rendered_spec_success(self, _get_spec, post): """ pass in a bad id to update and make sure it returns False diff --git a/tests/test_tabi.py b/tests/test_tabi.py index 0241191..3119d6f 100644 --- a/tests/test_tabi.py +++ b/tests/test_tabi.py @@ -682,7 +682,7 @@ def test_execute_query_uses_absolute_url_and_auth_header(self): return_value=("test-api-key", "https://na1.polyapi.io"), ), patch( - "polyapi.poly_tables.requests.post", return_value=response + "polyapi.http_client.post", return_value=response ) as post_mock, ): result = execute_query("table-id-123", "select", {"where": {"id": "abc"}})