-
Notifications
You must be signed in to change notification settings - Fork 226
Expand file tree
/
Copy pathtool_registry.py
More file actions
138 lines (104 loc) · 3.96 KB
/
tool_registry.py
File metadata and controls
138 lines (104 loc) · 3.96 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
"""Tool plugin registry for cheetahclaws.
Provides a central registry for tool definitions, lookup, schema export,
dispatch with output truncation, and result caching for read-only tools.
"""
from __future__ import annotations
import hashlib
import json
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional
@dataclass
class ToolDef:
    """A tool plugin: its API-facing schema plus the callable that runs it.

    Fields:
        name: identifier the tool is registered and dispatched under
        schema: JSON-schema dict sent to the API (name, description, input_schema)
        func: invoked as ``func(params, config)`` and returns the result string
        read_only: set True only when the tool never mutates state
        concurrent_safe: set True when the tool may run in parallel with other tools
    """

    # Required fields; their order is the positional order of __init__.
    name: str
    schema: dict[str, Any]
    func: Callable[[dict[str, Any], dict[str, Any]], str]
    # Optional capability flags, both defaulting to the conservative False.
    read_only: bool = False
    concurrent_safe: bool = False
# --------------- internal state ---------------
_registry: Dict[str, ToolDef] = dict()  # tool name -> its ToolDef
# --------------- result cache (read-only tools only) ---------------
_CACHE_MAX = 64  # upper bound on the number of cached results
_cache: Dict[str, str] = dict()  # cache key (hash of name + params) -> result string
_cache_order: list[str] = list()  # cached keys; the front entry is evicted first
def _cache_key(name: str, params: Dict[str, Any]) -> str:
    """Derive a short, order-independent cache key for a tool call.

    The name and params are serialized to JSON with sorted keys (values that
    JSON cannot encode are stringified), hashed with SHA-256, and the first
    16 hex digits of the digest are returned.
    """
    payload = {"n": name, "p": params}
    serialized = json.dumps(payload, sort_keys=True, default=str)
    digest = hashlib.sha256(serialized.encode())
    return digest.hexdigest()[:16]
def clear_tool_cache() -> None:
    """Empty the tool result cache, e.g. after a file write makes results stale.

    Both the key->result map and the eviction-order list are emptied in place,
    so every module-level reference to them stays valid.
    """
    for store in (_cache, _cache_order):
        store.clear()
# --------------- public API ---------------
def register_tool(tool_def: ToolDef) -> None:
    """Register a tool, overwriting any existing tool with the same name.

    When an existing registration is replaced, the result cache is cleared as
    well. Cache keys are derived from the tool name and params only, so results
    produced by the previous implementation would otherwise keep being served
    for the new one.

    Args:
        tool_def: the tool definition; ``tool_def.name`` is the registry key.
    """
    if tool_def.name in _registry:
        # Drop results computed by the implementation being replaced.
        clear_tool_cache()
    _registry[tool_def.name] = tool_def
def get_tool(name: str) -> Optional[ToolDef]:
    """Return the ToolDef registered under *name*, or None when there is none."""
    try:
        return _registry[name]
    except KeyError:
        return None
def get_all_tools() -> List[ToolDef]:
    """Return a new list of every registered ToolDef.

    Order follows dict insertion order: a name that is re-registered keeps the
    position of its first registration.
    """
    return [*_registry.values()]
def get_tool_schemas() -> List[Dict[str, Any]]:
    """Collect the API schema of every registered tool, in registry order.

    The schema dicts are returned by reference, not copied.
    """
    schemas: List[Dict[str, Any]] = []
    for tool_def in _registry.values():
        schemas.append(tool_def.schema)
    return schemas
def execute_tool(
    name: str,
    params: Dict[str, Any],
    config: Dict[str, Any],
    max_output: int = 32000,
) -> str:
    """Dispatch a tool call by name.

    Read-only tools are served from a small LRU result cache keyed on
    ``(name, params)``. Running any tool that is not marked ``read_only``
    clears that cache first, because such a tool may change what read-only
    tools would return. The output-length limit is applied to every result,
    whether freshly computed or served from the cache; the cache itself stores
    the untruncated result so each call can use its own ``max_output``.

    Args:
        name: tool name
        params: tool input parameters dict
        config: runtime configuration dict
        max_output: maximum allowed output length in characters (the
            truncation marker itself is not counted)

    Returns:
        Tool result string, possibly truncated. Returns an error string
        instead of raising when the tool is unknown or when the tool raises
        an ``Exception``.
    """
    tool = get_tool(name)
    if tool is None:
        return f"Error: tool '{name}' not found."

    key: Optional[str] = None
    if tool.read_only:
        # Same name + same params = same result for a read-only tool.
        key = _cache_key(name, params)
        cached = _cache_lookup(key)
        if cached is not None:
            return _truncate_output(cached, max_output)
    else:
        # read_only=False means the tool may mutate state (files, processes,
        # ...), so any cached read-only result may be stale after it runs.
        clear_tool_cache()

    try:
        result = tool.func(params, config)
    except Exception as e:
        return f"Error executing {name}: {e}"

    if key is not None:
        _cache_store(key, result)
    return _truncate_output(result, max_output)


def _cache_lookup(key: str) -> Optional[str]:
    """Return the cached result for *key* and mark it most-recently used, or None."""
    if key not in _cache:
        return None
    # Move the key to the back so eviction is genuinely least-recently-used.
    if key in _cache_order:
        _cache_order.remove(key)
    _cache_order.append(key)
    return _cache[key]


def _cache_store(key: str, result: str) -> None:
    """Cache *result* under *key*, evicting least-recently-used entries over the limit."""
    _cache[key] = result
    _cache_order.append(key)
    while len(_cache_order) > _CACHE_MAX:
        oldest = _cache_order.pop(0)
        _cache.pop(oldest, None)


def _truncate_output(result: str, max_output: int) -> str:
    """Shorten *result* to about *max_output* characters.

    Keeps the first half and the last quarter of the budget and replaces the
    middle with a marker reporting how many characters were dropped. Results
    already within budget are returned unchanged.
    """
    if len(result) <= max_output:
        return result
    head_len = max(max_output // 2, 0)
    tail_len = max(max_output // 4, 0)
    dropped = len(result) - head_len - tail_len
    # Slice the tail by absolute index: ``result[-0:]`` would be the WHOLE
    # string, so a tiny budget (tail_len == 0) must produce an empty tail.
    tail = result[len(result) - tail_len:]
    return (
        result[:head_len]
        + f"\n[... {dropped} chars truncated ...]\n"
        + tail
    )
def clear_registry() -> None:
    """Remove all registered tools and drop all cached results. Intended for testing.

    The result cache is cleared too, so that a tool re-registered later under
    the same name (e.g. by the next test) cannot be served results produced by
    a previous implementation.
    """
    _registry.clear()
    clear_tool_cache()