@@ -44,10 +44,15 @@ def infer_state(problem_dir: str) -> dict[str, Any]:
4444 "sol_built" : (root / "solutions" / "sol.cpp" ).exists () or any (root .glob ("solutions/sol.*" )),
4545 "brute_built" : (root / "solutions" / "brute.cpp" ).exists () or any (root .glob ("solutions/brute.*" )),
4646 "validator_ready" : (root / "files" / "val.cpp" ).exists () or any (root .glob ("files/val.*" )),
47+ "validator_accuracy" : None ,
4748 "generator_built" : (root / "files" / "gen.cpp" ).exists () or any (root .glob ("files/gen.*" )),
4849 "stress_passed" : False ,
50+ "stress_completed_rounds" : 0 ,
51+ "stress_total_rounds" : 0 ,
4952 "checker_ready" : (root / "files" / "checker.cpp" ).exists () or any (root .glob ("files/checker.*" )),
53+ "checker_accuracy" : None ,
5054 "tests_generated" : any ((root / "tests" ).glob ("*.in" )) if (root / "tests" ).exists () else False ,
55+ "generated_test_count" : len (list ((root / "tests" ).glob ("*.in" ))) if (root / "tests" ).exists () else 0 ,
5156 "packaged" : (root / "problem.xml" ).exists (),
5257 }
5358
@@ -94,6 +99,26 @@ def deny(reason: str) -> None:
9499 print (json .dumps (output , ensure_ascii = False ))
95100
96101
102+ def quality_summary (state : dict [str , Any ]) -> str :
103+ return (
104+ "Workflow state: "
105+ f"created={ state ['created' ]} , "
106+ f"sol_built={ state ['sol_built' ]} , "
107+ f"brute_built={ state ['brute_built' ]} , "
108+ f"validator_ready={ state ['validator_ready' ]} , "
109+ f"validator_accuracy={ state .get ('validator_accuracy' )} , "
110+ f"generator_built={ state ['generator_built' ]} , "
111+ f"stress_passed={ state ['stress_passed' ]} , "
112+ f"stress_completed_rounds={ state .get ('stress_completed_rounds' , 0 )} , "
113+ f"stress_total_rounds={ state .get ('stress_total_rounds' , 0 )} , "
114+ f"checker_ready={ state ['checker_ready' ]} , "
115+ f"checker_accuracy={ state .get ('checker_accuracy' )} , "
116+ f"tests_generated={ state ['tests_generated' ]} , "
117+ f"generated_test_count={ state .get ('generated_test_count' , 0 )} , "
118+ f"packaged={ state ['packaged' ]} ."
119+ )
120+
121+
97122def pre_tool (payload : dict [str , Any ]) -> int :
98123 short_name = tool_short_name (payload .get ("tool_name" , "" ))
99124 problem_dir = get_problem_dir (payload )
@@ -109,11 +134,11 @@ def pre_tool(payload: dict[str, Any]) -> int:
109134 "solution_build_brute" : "必须先构建标准解 sol,再构建 brute。" ,
110135 "solution_build" : "必须先运行 problem_create 创建题目目录。" ,
111136 "validator_build" : "必须先完成 problem_create、solution_build(sol)、solution_build(brute)。" ,
112- "generator_build" : "必须先完成 validator_build,并让 validator 达到可用状态 。" ,
113- "stress_test_run" : "必须先完成 validator_build 和 generator_build,然后再进行 stress_test_run。" ,
114- "checker_build" : "必须先通过 stress_test_run,再构建 checker。" ,
115- "problem_generate_tests" : "必须先通过 stress_test_run,才能生成最终测试数据。" ,
116- "problem_pack_polygon" : "必须先生成最终测试数据,再进行打包。" ,
137+ "generator_build" : "必须先完成 validator_build,并且 validator accuracy >= 0.9 。" ,
138+ "stress_test_run" : "必须先完成 validator_build(accuracy >= 0.9) 和 generator_build,然后再进行 stress_test_run。" ,
139+ "checker_build" : "必须先通过 stress_test_run(completed_rounds == total_rounds) ,再构建 checker。" ,
140+ "problem_generate_tests" : "必须先通过 stress_test_run(completed_rounds == total_rounds) ,才能生成最终测试数据。" ,
141+ "problem_pack_polygon" : "必须先生成最终测试数据,并且生成数量 > 0, 再进行打包。" ,
117142 }
118143
119144 tool_input = payload .get ("tool_input" , {})
@@ -130,12 +155,18 @@ def pre_tool(payload: dict[str, Any]) -> int:
130155 deny (reasons ["validator_build" ])
131156 return 0
132157
133- if short_name == "generator_build" and not state ["validator_ready" ]:
158+ if short_name == "generator_build" and not (
159+ state ["validator_ready" ] and (state .get ("validator_accuracy" ) is None or state .get ("validator_accuracy" , 0 ) >= 0.9 )
160+ ):
134161 deny (reasons ["generator_build" ])
135162 return 0
136163
137164 if short_name == "stress_test_run" and not (
138- state ["sol_built" ] and state ["brute_built" ] and state ["validator_ready" ] and state ["generator_built" ]
165+ state ["sol_built" ]
166+ and state ["brute_built" ]
167+ and state ["validator_ready" ]
168+ and state .get ("validator_accuracy" , 0 ) >= 0.9
169+ and state ["generator_built" ]
139170 ):
140171 deny (reasons ["stress_test_run" ])
141172 return 0
@@ -148,7 +179,9 @@ def pre_tool(payload: dict[str, Any]) -> int:
148179 deny (reasons ["problem_generate_tests" ])
149180 return 0
150181
151- if short_name == "problem_pack_polygon" and not state ["tests_generated" ]:
182+ if short_name == "problem_pack_polygon" and not (
183+ state ["tests_generated" ] and state .get ("generated_test_count" , 0 ) > 0
184+ ):
152185 deny (reasons ["problem_pack_polygon" ])
153186 return 0
154187
@@ -176,16 +209,23 @@ def post_tool(payload: dict[str, Any]) -> int:
176209 elif solution_type == "brute" :
177210 state ["brute_built" ] = True
178211 elif short_name == "validator_build" :
179- state ["validator_ready" ] = data .get ("accuracy" , 1.0 ) >= 0.9
212+ accuracy = data .get ("accuracy" )
213+ state ["validator_accuracy" ] = accuracy
214+ state ["validator_ready" ] = accuracy is None or accuracy >= 0.9
180215 elif short_name == "generator_build" :
181216 state ["generator_built" ] = True
182217 elif short_name == "stress_test_run" :
218+ state ["stress_completed_rounds" ] = data .get ("completed_rounds" , 0 )
219+ state ["stress_total_rounds" ] = data .get ("total_rounds" , 0 )
183220 state ["stress_passed" ] = data .get ("completed_rounds" ) == data .get ("total_rounds" )
184221 elif short_name == "checker_build" :
185- state ["checker_ready" ] = data .get ("accuracy" , 1.0 ) >= 0.9
222+ accuracy = data .get ("accuracy" )
223+ state ["checker_accuracy" ] = accuracy
224+ state ["checker_ready" ] = accuracy is None or accuracy >= 0.9
186225 elif short_name == "problem_generate_tests" :
187226 generated_tests = data .get ("generated_tests" , [])
188227 state ["tests_generated" ] = bool (generated_tests )
228+ state ["generated_test_count" ] = len (generated_tests )
189229 elif short_name == "problem_pack_polygon" :
190230 state ["packaged" ] = True
191231
@@ -194,13 +234,26 @@ def post_tool(payload: dict[str, Any]) -> int:
194234
195235
196236def session_start () -> int :
197- reminder = (
198- "AutoCode plugin active. Enforce this workflow: "
237+ additional_context = (
238+ "AutoCode plugin active. Enforce this workflow with quality gates : "
199239 "problem_create -> solution_build(sol) -> solution_build(brute) -> "
200- "validator_build -> generator_build -> stress_test_run -> "
201- "checker_build(if needed) -> problem_generate_tests -> problem_pack_polygon."
240+ "validator_build(accuracy >= 0.9) -> generator_build -> "
241+ "stress_test_run(completed_rounds == total_rounds) -> "
242+ "checker_build if needed (accuracy >= 0.9) -> "
243+ "problem_generate_tests(generated_test_count > 0) -> problem_pack_polygon. "
244+ "If a hook blocks a step, complete the missing prerequisite instead of retrying blindly."
245+ )
246+ print (
247+ json .dumps (
248+ {
249+ "hookSpecificOutput" : {
250+ "hookEventName" : "SessionStart" ,
251+ "additionalContext" : additional_context ,
252+ }
253+ },
254+ ensure_ascii = False ,
255+ )
202256 )
203- print (reminder )
204257 return 0
205258
206259
0 commit comments