diff --git a/polyapi/client.py b/polyapi/client.py index 2c3068e..7828944 100644 --- a/polyapi/client.py +++ b/polyapi/client.py @@ -1,7 +1,10 @@ +import ast +import symtable as symtable_mod from typing import Any, Dict, List, Tuple from polyapi.typedefs import PropertySpecification from polyapi.utils import parse_arguments, get_type_and_def +from polyapi.constants import SAFE_IMPORT_MODULES DEFS_TEMPLATE = """ from typing import List, Dict, Any, TypedDict @@ -10,20 +13,223 @@ """ -def _wrap_code_in_try_except(function_name: str, code: str) -> str: - """ this is necessary because client functions with imports will blow up ALL server functions, - even if they don't use them. - because the server function will try to load all client functions when loading the library +def _is_safe_import(node: ast.stmt) -> bool: + """Check if an import statement is safe to place at module scope. + + Safe imports are stdlib and typing modules that will never raise ImportError. + """ + if isinstance(node, ast.Import): + return all( + alias.name.split('.')[0] in SAFE_IMPORT_MODULES + for alias in node.names + ) + if isinstance(node, ast.ImportFrom): + module = node.module or '' + return module.split('.')[0] in SAFE_IMPORT_MODULES + return False + + +def _rhs_is_type_construct(node: ast.expr) -> bool: + """Check if an assignment RHS is a typing construct. + + This is the ONE narrow heuristic we still need because symtable + can't distinguish `X = Literal["a"]` (type alias) from `x = foo()` (runtime). + + We check the VALUE, not the name — much more reliable than naming conventions. """ - prefix = """logger = logging.getLogger("poly") -try: + # X = Literal[...], X = Dict[str, Any], X = list[Foo], X = Union[...] + if isinstance(node, ast.Subscript): + return True + # X = str | int | float new Union + if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr): + return True + # X = TypedDict("X", {...}) — functional form + if isinstance(node, ast.Call) and isinstance(node.func, ast.Name): + if node.func.id in ('TypedDict', 'NamedTuple', 'NewType'): + return True + return False + + +def _extract_type_definitions(code: str) -> Tuple[str, str, str]: + """Split client function code into type definitions and runtime code. + + Uses symtable for definitive classification + dependency tracking. + Uses AST only for source line extraction. + + Returns: + (type_imports_code, type_defs_code, runtime_code) + """ + try: + tree = ast.parse(code) + st = symtable_mod.symtable(code, '', 'exec') + except SyntaxError: + return "", "", code + + lines = code.split('\n') + + # Phase 1: Build child table index — name -> type ('class' | 'function') + child_types: dict[str, str] = {} + child_tables: dict[str, symtable_mod.SymbolTable] = {} + for child in st.get_children(): + child_types[child.get_name()] = child.get_type() + child_tables[child.get_name()] = child + + # Phase 2: Identify all class names — these are ALWAYS module-scope + class_names: set[str] = { + name for name, kind in child_types.items() if kind == 'class' + } + + # Phase 2b: type aliases (Python 3.12+): type X = ... + if hasattr(ast, 'TypeAlias'): + for node in ast.iter_child_nodes(tree): + if isinstance(node, ast.TypeAlias) and isinstance(node.name, ast.Name): + class_names.add(node.name.id) + + # Phase 3: Compute transitive dependencies of all classes + # Any module-level symbol that a class references (directly or transitively) + # must also be at module scope + module_scope_names: set[str] = set(class_names) + + # Get all module-level assigned symbol names for reference + module_level_symbols: set[str] = { + sym.get_name() for sym in st.get_symbols() if sym.is_assigned() + } + + # BFS: find all module-level symbols reachable from classes + queue = list(class_names) + while queue: + name = queue.pop() + if name not in child_tables: + continue + for sym in child_tables[name].get_symbols(): + if sym.is_free() or (sym.is_global() and sym.is_referenced()): + dep = sym.get_name() + if dep in module_level_symbols and dep not in module_scope_names: + module_scope_names.add(dep) + queue.append(dep) # transitively check this dep's deps + + # Phase 4: Classify each AST node using the symtable results + type_import_lines: set[int] = set() + type_def_lines: set[int] = set() + + prev_was_type_def = False + + for node in ast.iter_child_nodes(tree): + start = node.lineno - 1 + end = node.end_lineno or start + 1 + + is_type_import = False + is_type_def = False + + # Imports: safe typing/stdlib imports go to module scope + if isinstance(node, (ast.Import, ast.ImportFrom)): + is_type_import = _is_safe_import(node) + + # Class definitions: symtable confirmed these are classes + elif isinstance(node, ast.ClassDef): + is_type_def = node.name in class_names # always True, but explicit + + # type aliases (Python 3.12+): type X = ... + elif hasattr(ast, 'TypeAlias') and isinstance(node, ast.TypeAlias): + is_type_def = True + + # Assignments: check if target is in our module_scope_names set + elif isinstance(node, ast.Assign) and len(node.targets) == 1: + if isinstance(node.targets[0], ast.Name): + is_type_def = node.targets[0].id in module_scope_names + + # Annotated assignments with value + elif isinstance(node, ast.AnnAssign) and node.value is not None: + if isinstance(node.target, ast.Name): + is_type_def = node.target.id in module_scope_names + + # Function definitions: NEVER module scope (these are runtime) + elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + is_type_def = False + + # Docstrings following type defs: keep with them + elif (isinstance(node, ast.Expr) + and isinstance(node.value, ast.Constant) + and isinstance(node.value.value, str) + and prev_was_type_def): + is_type_def = True + + if is_type_import: + for i in range(start, end): + type_import_lines.add(i) + if is_type_def: + for i in range(start, end): + type_def_lines.add(i) + + prev_was_type_def = is_type_def or is_type_import + + # Phase 5: Also promote assignments that LOOK like type aliases + # even if no class references them yet. + # This catches stuff like: DatadogStatus = Literal[...] when only used by functions + # symtable can't distinguish type aliases from variables, + # so this is the ONE remaining heuristic — but scoped narrowly to + # assignments whose RHS is a typing construct (Subscript/BinOp with |) + for node in ast.iter_child_nodes(tree): + if isinstance(node, ast.Assign) and len(node.targets) == 1: + target = node.targets[0] + if isinstance(target, ast.Name) and target.id not in module_scope_names: + if _rhs_is_type_construct(node.value): + start = node.lineno - 1 + end = node.end_lineno or start + 1 + for i in range(start, end): + type_def_lines.add(i) + module_scope_names.add(target.id) + + # Build output + imports_out: list[str] = [] + types_out: list[str] = [] + runtime_out: list[str] = [] + for i, line in enumerate(lines): + if i in type_import_lines: + imports_out.append(line) + elif i in type_def_lines: + types_out.append(line) + else: + runtime_out.append(line) + + return ( + '\n'.join(imports_out).strip(), + '\n'.join(types_out).strip(), + '\n'.join(runtime_out).strip(), + ) + + +def _wrap_code_in_try_except(function_name: str, code: str) -> Tuple[str, str]: + """Split client code: types at module scope, runtime in try/except. + + Returns: + (module_scope_code, try_except_code) + + module_scope_code: safe imports + type definitions (always available) + try_except_code: runtime code wrapped in try/except ImportError """ - suffix = f"""except ImportError as e: - logger.warning("Failed to import client function '{function_name}', function unavailable: " + str(e))""" + type_imports, type_defs, runtime_code = _extract_type_definitions(code) + + # Build module-scope section + module_parts = [] + if type_imports: + module_parts.append(type_imports) + if type_defs: + module_parts.append(type_defs) + module_scope = '\n\n'.join(module_parts) - lines = code.split("\n") - code = "\n ".join(lines) - return "".join([prefix, code, "\n", suffix]) + # Build try/except section for runtime code + prefix = f'logger = logging.getLogger("poly")\ntry:\n ' + suffix = ( + f"\nexcept ImportError as e:\n" + f" logger.warning(\"Failed to import client function " + f"'{function_name}', function unavailable: \" + str(e))" + ) + + indented = '\n '.join(runtime_code.split('\n')) + wrapped = prefix + indented + suffix + + return module_scope, wrapped def render_client_function( @@ -31,7 +237,16 @@ def render_client_function( code: str, arguments: List[PropertySpecification], return_type: Dict[str, Any], -) -> Tuple[str, str]: +) -> Tuple[str, str, str]: + """Render a client function into three parts. + + Returns: + (module_scope_types, wrapped_runtime, func_type_defs) + + module_scope_types: type definitions to place at module scope (deduplicated by caller) + wrapped_runtime: function code wrapped in try/except + func_type_defs: SDK-generated type stubs for the {FuncName}.py file + """ args, args_def = parse_arguments(function_name, arguments) return_type_name, return_type_def = get_type_and_def(return_type) # type: ignore func_type_defs = DEFS_TEMPLATE.format( @@ -39,6 +254,6 @@ def render_client_function( return_type_def=return_type_def, ) - code = _wrap_code_in_try_except(function_name, code) + module_scope, wrapped = _wrap_code_in_try_except(function_name, code) - return code + "\n\n", func_type_defs \ No newline at end of file + return module_scope, wrapped + "\n\n", func_type_defs diff --git a/polyapi/constants.py b/polyapi/constants.py index 16bbeaf..1a502a2 100644 --- a/polyapi/constants.py +++ b/polyapi/constants.py @@ -23,5 +23,15 @@ BASIC_PYTHON_TYPES = set(PYTHON_TO_JSONSCHEMA_TYPE_MAP.keys()) +# initial pass +SAFE_IMPORT_MODULES = { + "typing", "typing_extensions", "types", + "re", "os", "sys", "json", "datetime", "math", + "collections", "enum", "dataclasses", "abc", + "functools", "itertools", "operator", + "urllib", "urllib.parse", "pathlib", + "copy", "hashlib", "uuid", +} + # TODO wire this up to config-variables in future so clients can modify SUPPORT_EMAIL = 'support@polyapi.io' \ No newline at end of file diff --git a/polyapi/generate.py b/polyapi/generate.py index 00366c9..9bce351 100644 --- a/polyapi/generate.py +++ b/polyapi/generate.py @@ -4,6 +4,7 @@ import uuid import shutil import logging +import ast import tempfile from copy import deepcopy @@ -22,6 +23,11 @@ from .poly_tables import generate_tables from .config import get_api_key_and_url, get_direct_execute_config, get_cached_generate_args +# Track emitted type definitions per __init__.py for deduplication +# Maps: directory_path -> {type_name -> source_code} +_emitted_types: dict[str, dict[str, str]] = {} + + SUPPORTED_FUNCTION_TYPES = { "apiFunction", "authFunction", @@ -342,7 +348,7 @@ def generate(contexts: Optional[List[str]] = None, names: Optional[List[str]] = generate_msg = f"Generating Poly Python SDK for contexts ${contexts}..." if contexts else "Generating Poly Python SDK..." print(generate_msg, end="", flush=True) remove_old_library() - + _emitted_types.clear() specs = get_specs(contexts=contexts, names=names, ids=ids, no_types=no_types) cache_specs(specs) @@ -397,7 +403,16 @@ def clear() -> None: print("Cleared!") -def render_spec(spec: SpecificationDto) -> Tuple[str, str]: +def render_spec(spec: SpecificationDto) -> Tuple[str, str, str]: + """Render a spec into generated code. + + Returns: + (module_scope_types, func_str, func_type_defs) + + module_scope_types: type definitions for module scope (client functions only) + func_str: function code (wrapped in try/except for client functions) + func_type_defs: type stubs for the {FuncName}.py IDE helper file + """ function_type = spec["type"] function_description = spec["description"] function_name = spec["name"] @@ -410,10 +425,7 @@ def render_spec(spec: SpecificationDto) -> Tuple[str, str]: assert spec["function"] # Handle cases where arguments might be missing or None if spec["function"].get("arguments"): - arguments = [ - arg for arg in spec["function"]["arguments"] - ] - + arguments = [arg for arg in spec["function"]["arguments"]] # Handle cases where returnType might be missing or None if spec["function"].get("returnType"): return_type = spec["function"]["returnType"] @@ -421,49 +433,31 @@ def render_spec(spec: SpecificationDto) -> Tuple[str, str]: # Provide a fallback return type when missing return_type = {"kind": "any"} + module_scope_types = "" + if function_type == "apiFunction": func_str, func_type_defs = render_api_function( - function_type, - function_name, - function_id, - function_description, - arguments, - return_type, + function_type, function_name, function_id, + function_description, arguments, return_type, ) elif function_type == "customFunction": - func_str, func_type_defs = render_client_function( - function_name, - spec.get("code", ""), - arguments, - return_type, + module_scope_types, func_str, func_type_defs = render_client_function( + function_name, spec.get("code", ""), arguments, return_type, ) elif function_type == "serverFunction": func_str, func_type_defs = render_server_function( - function_type, - function_name, - function_id, - function_description, - arguments, - return_type, + function_type, function_name, function_id, + function_description, arguments, return_type, ) elif function_type == "authFunction": func_str, func_type_defs = render_auth_function( - function_type, - function_name, - function_id, - function_description, - arguments, - return_type, + function_type, function_name, function_id, + function_description, arguments, return_type, ) elif function_type == "webhookHandle": func_str, func_type_defs = render_webhook_handle( - function_type, - function_context, - function_name, - function_id, - function_description, - arguments, - return_type, + function_type, function_context, function_name, + function_id, function_description, arguments, return_type, ) if X_POLY_REF_WARNING in func_type_defs: @@ -471,7 +465,62 @@ def render_spec(spec: SpecificationDto) -> Tuple[str, str]: # let's add a more user friendly error explaining what is going on func_type_defs = func_type_defs.replace(X_POLY_REF_WARNING, X_POLY_REF_BETTER_WARNING) - return func_str, func_type_defs + return module_scope_types, func_str, func_type_defs + + +def _deduplicate_types(full_path: str, module_scope_types: str) -> str: + """Remove type definitions that have already been emitted for this context. + + Uses AST to parse the module_scope_types string, identifies each + top-level definition by name, and strips duplicates. + + Keeps definitions whose name hasn't been seen, or whose content differs + from the previously emitted version (last-writer-wins for changed types). + """ + if not module_scope_types.strip(): + return "" + + if full_path not in _emitted_types: + _emitted_types[full_path] = {} + + seen = _emitted_types[full_path] + + try: + tree = ast.parse(module_scope_types) + except SyntaxError: + return module_scope_types # Can't parse, emit as-is + + lines = module_scope_types.split('\n') + keep_lines: set[int] = set(range(len(lines))) # Start by keeping all + + for node in ast.iter_child_nodes(tree): + name = None + if isinstance(node, ast.ClassDef): + name = node.name + elif isinstance(node, ast.Assign) and len(node.targets) == 1: + target = node.targets[0] + if isinstance(target, ast.Name): + name = target.id + elif isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name): + name = node.target.id + + if not name: + continue + + start = node.lineno - 1 + end = node.end_lineno or start + 1 + source = '\n'.join(lines[start:end]) + + if name in seen and seen[name] == source: + # Exact duplicate — skip it + for i in range(start, end): + keep_lines.discard(i) + else: + # New or changed — keep it, update registry + seen[name] = source + + return '\n'.join(line for i, line in enumerate(lines) if i in keep_lines).strip() + def add_function_file( @@ -479,62 +528,58 @@ def add_function_file( function_name: str, spec: SpecificationDto, ): - """ - Atomically add a function file to prevent partial corruption during generation failures. + """Atomically add a function to a context's __init__.py. - This function generates all content first, then writes files atomically using temporary files - to ensure that either the entire operation succeeds or no changes are made to the filesystem. + For client functions, type definitions are placed at module scope + (outside try/except) and deduplicated across sibling functions. """ try: - # first lets add the import to the __init__ init_the_init(full_path) - func_str, func_type_defs = render_spec(spec) + module_scope_types, func_str, func_type_defs = render_spec(spec) if not func_str: - # If render_spec failed and returned empty string, don't create any files raise Exception("Function rendering failed - empty function string returned") - # Prepare all content first before writing any files + # Deduplicate types against previously emitted types for this context + unique_types = _deduplicate_types(full_path, module_scope_types) + func_namespace = to_func_namespace(function_name) init_path = os.path.join(full_path, "__init__.py") func_file_path = os.path.join(full_path, f"{func_namespace}.py") - # Read current __init__.py content if it exists init_content = "" if os.path.exists(init_path): with open(init_path, "r", encoding='utf-8') as f: init_content = f.read() - # Prepare new content to append to __init__.py - new_init_content = init_content + f"\n\nfrom . import {func_namespace}\n\n{func_str}" + # Build new content: import + module-scope types + wrapped function + new_parts = [init_content, f"\n\nfrom . import {func_namespace}\n"] + if unique_types: + new_parts.append(f"\n{unique_types}\n") + new_parts.append(f"\n{func_str}") + new_init_content = ''.join(new_parts) - # Use temporary files for atomic writes - # Write to __init__.py atomically + # Atomic writes with tempfile.NamedTemporaryFile(mode="w", delete=False, dir=full_path, suffix=".tmp", encoding='utf-8') as temp_init: temp_init.write(new_init_content) temp_init_path = temp_init.name - # Write to function file atomically with tempfile.NamedTemporaryFile(mode="w", delete=False, dir=full_path, suffix=".tmp", encoding='utf-8') as temp_func: temp_func.write(func_type_defs) temp_func_path = temp_func.name - # Atomic operations: move temp files to final locations shutil.move(temp_init_path, init_path) shutil.move(temp_func_path, func_file_path) except Exception as e: - # Clean up any temporary files that might have been created try: if 'temp_init_path' in locals() and os.path.exists(temp_init_path): os.unlink(temp_init_path) if 'temp_func_path' in locals() and os.path.exists(temp_func_path): os.unlink(temp_func_path) except: - pass # Best effort cleanup - - # Re-raise the original exception + pass raise e diff --git a/polyapi/utils.py b/polyapi/utils.py index fb57199..08a12be 100644 --- a/polyapi/utils.py +++ b/polyapi/utils.py @@ -16,8 +16,16 @@ # this string should be in every __init__ file. # it contains all the imports needed for the function or variable code to run -CODE_IMPORTS = "from typing import List, Dict, Any, Optional, Callable\nfrom typing_extensions import TypedDict, NotRequired\nimport logging\nimport requests\nimport socketio # type: ignore\nfrom polyapi.config import get_api_key_and_url, get_direct_execute_config\nfrom polyapi.execute import execute, execute_post, variable_get, variable_update, direct_execute\n\n" - +CODE_IMPORTS = ( + "from __future__ import annotations\n" # main char + "from typing import List, Dict, Any, Optional, Callable, Union\n" + "from typing_extensions import TypedDict, NotRequired, Literal\n" + "import logging\n" + "import requests\n" + "import socketio # type: ignore\n" + "from polyapi.config import get_api_key_and_url, get_direct_execute_config\n" + "from polyapi.execute import execute, execute_post, variable_get, variable_update, direct_execute\n\n" +) def init_the_init(full_path: str, code_imports: Optional[str] = None) -> None: init_path = os.path.join(full_path, "__init__.py")