This repository was archived by the owner on Jun 5, 2025. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 90
Expand file tree
/
Copy path: rulematcher.py
More file actions
226 lines (177 loc) · 7.88 KB
/
rulematcher.py
File metadata and controls
226 lines (177 loc) · 7.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
import copy
import fnmatch
from abc import ABC, abstractmethod
from asyncio import Lock
from typing import Dict, List, Optional
import structlog
from codegate.clients.clients import ClientType
from codegate.db import models as db_models
from codegate.extract_snippets.body_extractor import BodyCodeSnippetExtractorError
from codegate.extract_snippets.factory import BodyCodeExtractorFactory
from codegate.muxing import models as mux_models
logger = structlog.get_logger("codegate")

# Lazily-created registry instance shared by the whole module; populated by
# get_muxing_rules_registry() on first call.
_muxrules_sgtn: Optional["MuxingRulesinWorkspaces"] = None
# Serializes the one-time creation of _muxrules_sgtn between coroutines.
_singleton_lock: Lock = Lock()
class MuxMatchingError(Exception):
    """Error raised when a muxing rule cannot be evaluated against a request.

    For example, it is raised when the filenames referenced by a request body
    cannot be extracted.
    """
async def get_muxing_rules_registry():
    """Return the module-level ``MuxingRulesinWorkspaces`` registry, creating it on first use.

    The ``None`` check is repeated while holding ``_singleton_lock`` so that
    coroutines racing on the first call all end up with the same instance.
    """
    global _muxrules_sgtn
    # Fast path: the registry already exists, no locking required.
    if _muxrules_sgtn is not None:
        return _muxrules_sgtn
    async with _singleton_lock:
        if _muxrules_sgtn is None:
            _muxrules_sgtn = MuxingRulesinWorkspaces()
    return _muxrules_sgtn
class ModelRoute:
    """Destination of a muxed request: a provider model, its endpoint and its auth material."""

    def __init__(
        self,
        model: db_models.ProviderModel,
        endpoint: db_models.ProviderEndpoint,
        auth_material: db_models.ProviderAuthMaterial,
    ):
        # All three routing components are exposed as plain public attributes.
        self.model, self.endpoint, self.auth_material = model, endpoint, auth_material
class MuxingRuleMatcher(ABC):
    """Abstract base for a muxing rule that decides whether it applies to a request.

    Concrete subclasses implement ``match``. ``destination`` returns the
    ``ModelRoute`` the rule was built with.
    """

    def __init__(self, route: ModelRoute, mux_rule: mux_models.MuxRule):
        self._route = route
        self._mux_rule = mux_rule

    def destination(self) -> ModelRoute:
        """Return the route that requests matching this rule are sent to."""
        return self._route

    @abstractmethod
    def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool:
        """Return True when this rule applies to ``thing_to_match``."""
        ...
class MuxingMatcherFactory:
    """Factory for creating muxing matchers."""

    @staticmethod
    def create(
        db_mux_rule: db_models.MuxRule,
        db_provider_endpoint: db_models.ProviderEndpoint,
        route: ModelRoute,
    ) -> MuxingRuleMatcher:
        """Create the muxing matcher that implements ``db_mux_rule``.

        Args:
            db_mux_rule: The rule as stored in the database.
            db_provider_endpoint: Provider endpoint passed, together with the rule,
                to ``MuxRule.from_db_models``.
            route: Destination handed to the created matcher.

        Returns:
            An instance of the ``MuxingRuleMatcher`` subclass registered for the
            rule's ``matcher_type``.

        Raises:
            ValueError: If the rule's ``matcher_type`` has no matcher implementation.
        """
        factory: Dict[mux_models.MuxMatcherType, type[MuxingRuleMatcher]] = {
            mux_models.MuxMatcherType.catch_all: CatchAllMuxingRuleMatcher,
            mux_models.MuxMatcherType.filename_match: FileMuxingRuleMatcher,
            mux_models.MuxMatcherType.fim_filename: RequestTypeAndFileMuxingRuleMatcher,
            mux_models.MuxMatcherType.chat_filename: RequestTypeAndFileMuxingRuleMatcher,
        }
        # Convert outside of any KeyError handling: if the conversion itself fails,
        # that error must surface as-is. It must not be reported as an unknown
        # matcher type, or reach a handler that reads a `mux_rule` never assigned.
        mux_rule = mux_models.MuxRule.from_db_models(db_mux_rule, db_provider_endpoint)
        matcher_cls = factory.get(mux_rule.matcher_type)
        if matcher_cls is None:
            raise ValueError(f"Unknown matcher type: {mux_rule.matcher_type}")
        return matcher_cls(route, mux_rule)
class CatchAllMuxingRuleMatcher(MuxingRuleMatcher):
    """Muxing rule matcher that accepts every request unconditionally."""

    def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool:
        # Nothing about the request is inspected: this rule always applies.
        logger.info("Catch all rule matched")
        return True
class FileMuxingRuleMatcher(MuxingRuleMatcher):
    """A muxing rule matcher that selects requests by the filenames they reference.

    The rule's ``matcher`` is treated as a glob pattern. The rule matches when any
    filename extracted from the request body matches that pattern. An empty
    ``matcher`` matches every request without inspecting the body.
    """

    def _extract_request_filenames(self, detected_client: ClientType, data: dict) -> set[str]:
        """
        Extract the unique filenames referenced in the request body.

        Args:
            detected_client: Client that sent the request; used to pick the body
                extractor implementation.
            data: The request body.

        Returns:
            The unique filenames found in ``data``.

        Raises:
            MuxMatchingError: If creating the extractor or extracting the filenames
                raises ``BodyCodeSnippetExtractorError``.
        """
        try:
            body_extractor = BodyCodeExtractorFactory.create_snippet_extractor(detected_client)
            return body_extractor.extract_unique_filenames(data)
        except BodyCodeSnippetExtractorError as e:
            logger.error(f"Error extracting filenames from request: {e}")
            # Chain explicitly so the extractor error is recorded as the direct
            # cause of the matching failure in tracebacks.
            raise MuxMatchingError("Error extracting filenames from request") from e

    def _is_matcher_in_filenames(self, detected_client: ClientType, data: dict) -> bool:
        """
        Check whether any request filename matches the rule's glob pattern.

        Args:
            detected_client: Client that sent the request.
            data: The request body.

        Returns:
            True if ``matcher`` is empty or at least one extracted filename matches it.

        Raises:
            MuxMatchingError: Propagated from ``_extract_request_filenames``.
        """
        # An empty matcher means "match everything"; skip body extraction entirely.
        if not self._mux_rule.matcher:
            return True

        filenames_to_match = self._extract_request_filenames(detected_client, data)
        # NOTE: fnmatch.fnmatch normalizes case via os.path.normcase. Matching is
        # therefore case-insensitive on Windows and case-sensitive on POSIX.
        # "/" is not special to fnmatch, so "*.py" also matches "src/app.py".
        return any(
            fnmatch.fnmatch(filename, self._mux_rule.matcher) for filename in filenames_to_match
        )

    def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool:
        """
        Return True if the rule's glob matches one of the request filenames.

        Raises:
            MuxMatchingError: If the request filenames cannot be extracted.
        """
        is_rule_matched = self._is_matcher_in_filenames(
            thing_to_match.client_type, thing_to_match.body
        )
        if is_rule_matched:
            logger.info("Filename rule matched", matcher=self._mux_rule.matcher)
        return is_rule_matched
class RequestTypeAndFileMuxingRuleMatcher(FileMuxingRuleMatcher):
    """A muxing rule matcher that checks both the request type and the filenames.

    A rule whose ``matcher_type`` is ``fim_filename`` only applies to FIM requests.
    A rule whose ``matcher_type`` is ``chat_filename`` only applies to non-FIM
    (chat) requests. In both cases the filename glob inherited from
    ``FileMuxingRuleMatcher`` must match as well.
    """

    def _is_request_type_match(self, is_fim_request: bool) -> bool:
        """
        Check whether the incoming request type matches the rule's matcher type.

        Args:
            is_fim_request: True for a FIM request, False for a chat request.

        Returns:
            True if ``matcher_type`` is ``fim_filename`` for a FIM request, or
            ``chat_filename`` for a chat request.
        """
        # Compare against the enum members rather than hard-coded strings. This
        # way the check does not rely on MuxMatcherType being str-based, or on
        # its values being spelled exactly like its member names.
        expected_type = (
            mux_models.MuxMatcherType.fim_filename
            if is_fim_request
            else mux_models.MuxMatcherType.chat_filename
        )
        return self._mux_rule.matcher_type == expected_type

    def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool:
        """
        Return True if the request type matches the rule's MuxMatcherType and
        the rule's glob matches one of the request filenames.

        Raises:
            MuxMatchingError: If the request type matches but the request
                filenames cannot be extracted.
        """
        # Check the cheap request-type condition first. When it fails, the rule
        # cannot apply, so the request body is never parsed for filenames, and an
        # extraction failure cannot abort matching for an inapplicable rule.
        is_rule_matched = self._is_request_type_match(
            thing_to_match.is_fim_request
        ) and self._is_matcher_in_filenames(thing_to_match.client_type, thing_to_match.body)
        if is_rule_matched:
            logger.info(
                "Request type and rule matched",
                matcher=self._mux_rule.matcher,
                is_fim_request=thing_to_match.is_fim_request,
            )
        return is_rule_matched
class MuxingRulesinWorkspaces:
    """Registry of muxing rules per workspace, plus the currently active workspace.

    Access to the per-workspace rules is serialized with an ``asyncio.Lock``.
    That makes it safe for concurrent coroutines on one event loop. It is not a
    thread-safe structure, because ``asyncio.Lock`` does not synchronize OS threads.
    """

    def __init__(self) -> None:
        super().__init__()
        self._lock = Lock()
        # Name of the workspace whose rules get_match_for_active_workspace() uses.
        self._active_workspace = ""
        # workspace name -> list of MuxingRuleMatcher, evaluated in list order.
        self._ws_rules = {}

    async def get_ws_rules(self, workspace_name: str) -> List[MuxingRuleMatcher]:
        """Return a deep copy of the rules for ``workspace_name`` ([] if it has none).

        The copy lets callers iterate the rules without holding the lock and
        without being able to mutate the stored rules.
        """
        async with self._lock:
            return copy.deepcopy(self._ws_rules.get(workspace_name, []))

    async def set_ws_rules(self, workspace_name: str, rules: List[MuxingRuleMatcher]) -> None:
        """Replace the rules for ``workspace_name`` with ``rules``."""
        async with self._lock:
            self._ws_rules[workspace_name] = rules

    async def delete_ws_rules(self, workspace_name: str) -> None:
        """Delete the rules for ``workspace_name``; unknown workspaces are ignored."""
        async with self._lock:
            self._ws_rules.pop(workspace_name, None)

    async def set_active_workspace(self, workspace_name: str) -> None:
        """Set the workspace whose rules are used for matching."""
        self._active_workspace = workspace_name

    async def get_registries(self) -> List[str]:
        """Return the names of all workspaces that currently have rules registered."""
        async with self._lock:
            return list(self._ws_rules.keys())

    async def get_match_for_active_workspace(
        self, thing_to_match: mux_models.ThingToMatchMux
    ) -> Optional[ModelRoute]:
        """Return the destination of the first active-workspace rule that matches.

        Rules are evaluated in registration order.

        Returns:
            The first matching rule's ``ModelRoute``, or None if no rule matches,
            including when the active workspace has no rules.

        Raises:
            RuntimeError: If a ``KeyError`` is raised while evaluating the rules.
                The original error is chained as the cause.
        """
        # get_ws_rules() takes the lock and returns a deep copy, so the rules can
        # be evaluated here without holding the lock.
        try:
            rules = await self.get_ws_rules(self._active_workspace)
            for rule in rules:
                if rule.match(thing_to_match):
                    return rule.destination()
            return None
        except KeyError as e:
            # NOTE(review): get_ws_rules() returns [] for unknown workspaces. A
            # KeyError here therefore does not mean "no rules"; it comes from
            # evaluating the rules. The message is kept for compatibility; the
            # chained cause shows where the error really originated.
            raise RuntimeError("No rules found for the active workspace") from e