44基于代码静态分析估算时间/空间复杂度,并推荐测试参数。
55"""
66
7+ from __future__ import annotations
8+
79import re
810
911from .base import Tool , ToolResult
@@ -56,30 +58,37 @@ def analyze_loop_complexity(code: str) -> str:
5658 Returns:
5759 估算的复杂度字符串
5860 """
59- # 统计嵌套循环层数
61+ # 循环模式
6062 loop_patterns = [
6163 r"\bfor\s*\(" ,
6264 r"\bwhile\s*\(" ,
63- r"\bfor\s+.*:\s* " , # range-based for
65+ r"\bfor\s+\w+ .*:" , # range-based for
6466 ]
6567
6668 max_nesting = 0
67- current_nesting = 0
69+ brace_depth = 0
6870
6971 lines = code .split ("\n " )
7072 for line in lines :
71- # 计算当前行的循环数
72- loop_count = 0
73- for pattern in loop_patterns :
74- loop_count += len (re .findall (pattern , line ))
75-
76- # 检测循环结束
77- brace_change = line .count ("{" ) - line .count ("}" )
78-
79- # 更新嵌套深度
80- current_nesting += loop_count
81- max_nesting = max (max_nesting , current_nesting )
82- current_nesting = max (0 , current_nesting + brace_change )
73+ # 移除单行注释
74+ if "//" in line :
75+ line = line [: line .index ("//" )]
76+
77+ # 检测当前行是否有循环
78+ has_loop = any (re .search (p , line ) for p in loop_patterns )
79+
80+ # 如果有循环,记录当前深度
81+ # brace_depth 表示当前所在的大括号层级
82+ # 循环嵌套数 = 当前大括号层级
83+ if has_loop :
84+ max_nesting = max (max_nesting , brace_depth )
85+
86+ # 处理大括号
87+ for char in line :
88+ if char == "{" :
89+ brace_depth += 1
90+ elif char == "}" :
91+ brace_depth = max (0 , brace_depth - 1 )
8392
8493 # 根据嵌套层数估算复杂度
8594 if max_nesting == 0 :
@@ -94,17 +103,18 @@ def analyze_loop_complexity(code: str) -> str:
94103 return ComplexityLevel .EXPONENTIAL
95104
96105
97- def detect_algorithm_patterns (code : str ) -> tuple [str , list [str ]]:
106+ def detect_algorithm_patterns (code : str ) -> tuple [str | None , list [str ]]:
98107 """检测常见算法模式。
99108
100109 Args:
101110 code: C++ 源代码
102111
103112 Returns:
104- (复杂度, 检测到的模式列表)
113+ (复杂度或 None, 检测到的模式列表)
114+ 如果没有检测到模式,返回 (None, [])
105115 """
106116 patterns = []
107- complexity = ComplexityLevel . LINEAR # 默认
117+ complexity = None # 默认不返回复杂度
108118
109119 # 二分查找
110120 if re .search (r"\bbinary_search\b|\blower_bound\b|\bupper_bound\b" , code ):
@@ -140,7 +150,7 @@ def detect_algorithm_patterns(code: str) -> tuple[str, list[str]]:
140150 patterns .append ("recursion" )
141151
142152 # 位运算
143- if re .search (r"1\s*<<\s*\d |bitmask|bitset" , code ):
153+ if re .search (r"1\s*<<\s*\w+ |bitmask|bitset" , code ):
144154 patterns .append ("bitmask" )
145155 complexity = ComplexityLevel .EXPONENTIAL
146156
@@ -248,25 +258,27 @@ async def execute(
248258 # 2. 检测算法模式
249259 pattern_complexity , patterns = detect_algorithm_patterns (code )
250260
251- # 3. 选择更优的复杂度估计
252- # 优先使用模式检测的结果
253- complexity_order = [
254- ComplexityLevel .CONSTANT ,
255- ComplexityLevel .LOG_N ,
256- ComplexityLevel .LINEAR ,
257- ComplexityLevel .N_LOG_N ,
258- ComplexityLevel .QUADRATIC ,
259- ComplexityLevel .CUBIC ,
260- ComplexityLevel .EXPONENTIAL ,
261- ComplexityLevel .FACTORIAL ,
262- ]
263-
264- loop_idx = complexity_order .index (loop_complexity )
265- pattern_idx = complexity_order .index (pattern_complexity )
266-
267- # 如果模式检测到更优的复杂度,使用它
268- if pattern_idx < loop_idx :
269- final_complexity = pattern_complexity
261+ # 3. 选择复杂度估计
262+ # 如果检测到算法模式,取两者中较大的(更保守的估计)
263+ if pattern_complexity is not None :
264+ complexity_order = [
265+ ComplexityLevel .CONSTANT ,
266+ ComplexityLevel .LOG_N ,
267+ ComplexityLevel .LINEAR ,
268+ ComplexityLevel .N_LOG_N ,
269+ ComplexityLevel .QUADRATIC ,
270+ ComplexityLevel .CUBIC ,
271+ ComplexityLevel .EXPONENTIAL ,
272+ ComplexityLevel .FACTORIAL ,
273+ ]
274+
275+ loop_idx = complexity_order .index (loop_complexity )
276+ pattern_idx = complexity_order .index (pattern_complexity )
277+
278+ # 取较大的复杂度(更保守)
279+ final_complexity = (
280+ pattern_complexity if pattern_idx > loop_idx else loop_complexity
281+ )
270282 else :
271283 final_complexity = loop_complexity
272284
0 commit comments