Skip to content

Commit 1745c25

Browse files
committed
feat: strengthen workflow guard quality gates
- add structured SessionStart additionalContext output
- enforce validator, stress-test, and test-generation quality thresholds
1 parent 551ea74 commit 1745c25

File tree

1 file changed

+68
-15
lines changed

1 file changed

+68
-15
lines changed

hooks/workflow_guard.py

Lines changed: 68 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -44,10 +44,15 @@ def infer_state(problem_dir: str) -> dict[str, Any]:
4444
"sol_built": (root / "solutions" / "sol.cpp").exists() or any(root.glob("solutions/sol.*")),
4545
"brute_built": (root / "solutions" / "brute.cpp").exists() or any(root.glob("solutions/brute.*")),
4646
"validator_ready": (root / "files" / "val.cpp").exists() or any(root.glob("files/val.*")),
47+
"validator_accuracy": None,
4748
"generator_built": (root / "files" / "gen.cpp").exists() or any(root.glob("files/gen.*")),
4849
"stress_passed": False,
50+
"stress_completed_rounds": 0,
51+
"stress_total_rounds": 0,
4952
"checker_ready": (root / "files" / "checker.cpp").exists() or any(root.glob("files/checker.*")),
53+
"checker_accuracy": None,
5054
"tests_generated": any((root / "tests").glob("*.in")) if (root / "tests").exists() else False,
55+
"generated_test_count": len(list((root / "tests").glob("*.in"))) if (root / "tests").exists() else 0,
5156
"packaged": (root / "problem.xml").exists(),
5257
}
5358

@@ -94,6 +99,26 @@ def deny(reason: str) -> None:
9499
print(json.dumps(output, ensure_ascii=False))
95100

96101

102+
def quality_summary(state: dict[str, Any]) -> str:
103+
return (
104+
"Workflow state: "
105+
f"created={state['created']}, "
106+
f"sol_built={state['sol_built']}, "
107+
f"brute_built={state['brute_built']}, "
108+
f"validator_ready={state['validator_ready']}, "
109+
f"validator_accuracy={state.get('validator_accuracy')}, "
110+
f"generator_built={state['generator_built']}, "
111+
f"stress_passed={state['stress_passed']}, "
112+
f"stress_completed_rounds={state.get('stress_completed_rounds', 0)}, "
113+
f"stress_total_rounds={state.get('stress_total_rounds', 0)}, "
114+
f"checker_ready={state['checker_ready']}, "
115+
f"checker_accuracy={state.get('checker_accuracy')}, "
116+
f"tests_generated={state['tests_generated']}, "
117+
f"generated_test_count={state.get('generated_test_count', 0)}, "
118+
f"packaged={state['packaged']}."
119+
)
120+
121+
97122
def pre_tool(payload: dict[str, Any]) -> int:
98123
short_name = tool_short_name(payload.get("tool_name", ""))
99124
problem_dir = get_problem_dir(payload)
@@ -109,11 +134,11 @@ def pre_tool(payload: dict[str, Any]) -> int:
109134
"solution_build_brute": "必须先构建标准解 sol,再构建 brute。",
110135
"solution_build": "必须先运行 problem_create 创建题目目录。",
111136
"validator_build": "必须先完成 problem_create、solution_build(sol)、solution_build(brute)。",
112-
"generator_build": "必须先完成 validator_build,并让 validator 达到可用状态。",
113-
"stress_test_run": "必须先完成 validator_build 和 generator_build,然后再进行 stress_test_run。",
114-
"checker_build": "必须先通过 stress_test_run,再构建 checker。",
115-
"problem_generate_tests": "必须先通过 stress_test_run,才能生成最终测试数据。",
116-
"problem_pack_polygon": "必须先生成最终测试数据,再进行打包。",
137+
"generator_build": "必须先完成 validator_build,并且 validator accuracy >= 0.9。",
138+
"stress_test_run": "必须先完成 validator_build(accuracy >= 0.9) 和 generator_build,然后再进行 stress_test_run。",
139+
"checker_build": "必须先通过 stress_test_run(completed_rounds == total_rounds),再构建 checker。",
140+
"problem_generate_tests": "必须先通过 stress_test_run(completed_rounds == total_rounds),才能生成最终测试数据。",
141+
"problem_pack_polygon": "必须先生成最终测试数据,并且生成数量 > 0,再进行打包。",
117142
}
118143

119144
tool_input = payload.get("tool_input", {})
@@ -130,12 +155,18 @@ def pre_tool(payload: dict[str, Any]) -> int:
130155
deny(reasons["validator_build"])
131156
return 0
132157

133-
if short_name == "generator_build" and not state["validator_ready"]:
158+
if short_name == "generator_build" and not (
159+
state["validator_ready"] and (state.get("validator_accuracy") is None or state.get("validator_accuracy", 0) >= 0.9)
160+
):
134161
deny(reasons["generator_build"])
135162
return 0
136163

137164
if short_name == "stress_test_run" and not (
138-
state["sol_built"] and state["brute_built"] and state["validator_ready"] and state["generator_built"]
165+
state["sol_built"]
166+
and state["brute_built"]
167+
and state["validator_ready"]
168+
and state.get("validator_accuracy", 0) >= 0.9
169+
and state["generator_built"]
139170
):
140171
deny(reasons["stress_test_run"])
141172
return 0
@@ -148,7 +179,9 @@ def pre_tool(payload: dict[str, Any]) -> int:
148179
deny(reasons["problem_generate_tests"])
149180
return 0
150181

151-
if short_name == "problem_pack_polygon" and not state["tests_generated"]:
182+
if short_name == "problem_pack_polygon" and not (
183+
state["tests_generated"] and state.get("generated_test_count", 0) > 0
184+
):
152185
deny(reasons["problem_pack_polygon"])
153186
return 0
154187

@@ -176,16 +209,23 @@ def post_tool(payload: dict[str, Any]) -> int:
176209
elif solution_type == "brute":
177210
state["brute_built"] = True
178211
elif short_name == "validator_build":
179-
state["validator_ready"] = data.get("accuracy", 1.0) >= 0.9
212+
accuracy = data.get("accuracy")
213+
state["validator_accuracy"] = accuracy
214+
state["validator_ready"] = accuracy is None or accuracy >= 0.9
180215
elif short_name == "generator_build":
181216
state["generator_built"] = True
182217
elif short_name == "stress_test_run":
218+
state["stress_completed_rounds"] = data.get("completed_rounds", 0)
219+
state["stress_total_rounds"] = data.get("total_rounds", 0)
183220
state["stress_passed"] = data.get("completed_rounds") == data.get("total_rounds")
184221
elif short_name == "checker_build":
185-
state["checker_ready"] = data.get("accuracy", 1.0) >= 0.9
222+
accuracy = data.get("accuracy")
223+
state["checker_accuracy"] = accuracy
224+
state["checker_ready"] = accuracy is None or accuracy >= 0.9
186225
elif short_name == "problem_generate_tests":
187226
generated_tests = data.get("generated_tests", [])
188227
state["tests_generated"] = bool(generated_tests)
228+
state["generated_test_count"] = len(generated_tests)
189229
elif short_name == "problem_pack_polygon":
190230
state["packaged"] = True
191231

@@ -194,13 +234,26 @@ def post_tool(payload: dict[str, Any]) -> int:
194234

195235

196236
def session_start() -> int:
197-
reminder = (
198-
"AutoCode plugin active. Enforce this workflow: "
237+
additional_context = (
238+
"AutoCode plugin active. Enforce this workflow with quality gates: "
199239
"problem_create -> solution_build(sol) -> solution_build(brute) -> "
200-
"validator_build -> generator_build -> stress_test_run -> "
201-
"checker_build(if needed) -> problem_generate_tests -> problem_pack_polygon."
240+
"validator_build(accuracy >= 0.9) -> generator_build -> "
241+
"stress_test_run(completed_rounds == total_rounds) -> "
242+
"checker_build if needed (accuracy >= 0.9) -> "
243+
"problem_generate_tests(generated_test_count > 0) -> problem_pack_polygon. "
244+
"If a hook blocks a step, complete the missing prerequisite instead of retrying blindly."
245+
)
246+
print(
247+
json.dumps(
248+
{
249+
"hookSpecificOutput": {
250+
"hookEventName": "SessionStart",
251+
"additionalContext": additional_context,
252+
}
253+
},
254+
ensure_ascii=False,
255+
)
202256
)
203-
print(reminder)
204257
return 0
205258

206259

0 commit comments

Comments (0)