diff --git a/crates/codegen/src/ir.rs b/crates/codegen/src/ir.rs index 745726bf39b..de5781ab647 100644 --- a/crates/codegen/src/ir.rs +++ b/crates/codegen/src/ir.rs @@ -1219,7 +1219,10 @@ pub(crate) fn label_exception_targets(blocks: &mut [Block]) { preserve_lasti, }); } else if is_pop { - debug_assert!(!stack.is_empty(), "POP_BLOCK with empty except stack at block {bi} instruction {i}"); + debug_assert!( + !stack.is_empty(), + "POP_BLOCK with empty except stack at block {bi} instruction {i}" + ); stack.pop(); // POP_BLOCK → NOP blocks[bi].instructions[i].instr = Instruction::Nop.into(); diff --git a/scripts/update_lib/cmd_auto_mark.py b/scripts/update_lib/cmd_auto_mark.py index 8b4dabd94bf..a62b2795bba 100644 --- a/scripts/update_lib/cmd_auto_mark.py +++ b/scripts/update_lib/cmd_auto_mark.py @@ -250,6 +250,99 @@ def path_to_test_parts(path: str) -> list[str]: return parts[-2:] +def _expand_stripped_to_children( + contents: str, + stripped_tests: set[tuple[str, str]], + all_failing_tests: set[tuple[str, str]], +) -> set[tuple[str, str]]: + """Find child-class failures that correspond to stripped parent-class markers. + + When ``strip_reasonless_expected_failures`` removes a marker from a parent + (mixin) class, test failures are reported against the concrete subclasses, + not the parent itself. This function maps those child failures back so + they get re-marked (and later consolidated to the parent by + ``_consolidate_to_parent``). + + Returns the set of ``(class, method)`` pairs from *all_failing_tests* that + should be re-marked. + """ + # Direct matches (stripped test itself is a concrete TestCase) + result = stripped_tests & all_failing_tests + + unmatched = stripped_tests - all_failing_tests + if not unmatched: + return result + + tree = ast.parse(contents) + class_bases, class_methods = _build_inheritance_info(tree) + + for parent_cls, method_name in unmatched: + if method_name not in class_methods.get(parent_cls, set()): + continue + for cls in _find_all_inheritors( + parent_cls, method_name, class_bases, class_methods + ): + if (cls, method_name) in all_failing_tests: + result.add((cls, method_name)) + + return result + + +def _consolidate_to_parent( + contents: str, + failing_tests: set[tuple[str, str]], + error_messages: dict[tuple[str, str], str] | None = None, +) -> tuple[set[tuple[str, str]], dict[tuple[str, str], str] | None]: + """Move failures to the parent class when ALL inheritors fail. + + If every concrete subclass that inherits a method from a parent class + appears in *failing_tests*, replace those per-subclass entries with a + single entry on the parent. This avoids creating redundant super-call + overrides in every child. + + Returns: + (consolidated_failing_tests, consolidated_error_messages) + """ + tree = ast.parse(contents) + class_bases, class_methods = _build_inheritance_info(tree) + + # Group by (defining_parent, method) → set of failing children + from collections import defaultdict + + groups: dict[tuple[str, str], set[str]] = defaultdict(set) + for class_name, method_name in failing_tests: + defining = _find_method_definition( + class_name, method_name, class_bases, class_methods + ) + if defining and defining != class_name: + groups[(defining, method_name)].add(class_name) + + if not groups: + return failing_tests, error_messages + + result = set(failing_tests) + new_error_messages = dict(error_messages) if error_messages else {} + + for (parent, method_name), failing_children in groups.items(): + all_inheritors = _find_all_inheritors( + parent, method_name, class_bases, class_methods + ) + + if all_inheritors and failing_children >= all_inheritors: + # All inheritors fail → mark on parent instead + children_keys = {(child, method_name) for child in failing_children} + result -= children_keys + result.add((parent, method_name)) + # Pick any child's error message for the parent + if new_error_messages: + for child in failing_children: + msg = new_error_messages.pop((child, method_name), "") + if msg: + new_error_messages[(parent, method_name)] = msg + + return result, new_error_messages or error_messages + + def build_patches( test_parts_set: set[tuple[str, str]], error_messages: dict[tuple[str, str], str] | None = None, @@ -293,6 +386,24 @@ def _is_super_call_only(func_node: ast.FunctionDef | ast.AsyncFunctionDef) -> bo return True +def _method_removal_range( + func_node: ast.FunctionDef | ast.AsyncFunctionDef, lines: list[str] +) -> range: + """Line range covering an entire method including decorators and a preceding COMMENT line.""" + first = ( + func_node.decorator_list[0].lineno - 1 + if func_node.decorator_list + else func_node.lineno - 1 + ) + if ( + first > 0 + and lines[first - 1].strip().startswith("#") + and COMMENT in lines[first - 1] + ): + first -= 1 + return range(first, func_node.end_lineno) + + def _build_inheritance_info(tree: ast.Module) -> tuple[dict, dict]: """ Build inheritance information from AST. @@ -348,6 +459,20 @@ def _find_method_definition( return None +def _find_all_inheritors( + parent: str, method_name: str, class_bases: dict, class_methods: dict +) -> set[str]: + """Find all classes that inherit *method_name* from *parent* (not overriding it).""" + return { + cls + for cls in class_bases + if cls != parent + and method_name not in class_methods.get(cls, set()) + and _find_method_definition(cls, method_name, class_bases, class_methods) + == parent + } + + def remove_expected_failures( contents: str, tests_to_remove: set[tuple[str, str]] ) -> str: @@ -383,15 +508,7 @@ def remove_expected_failures( remove_entire_method = _is_super_call_only(item) if remove_entire_method: - first_line = item.lineno - 1 - if item.decorator_list: - first_line = item.decorator_list[0].lineno - 1 - if first_line > 0: - prev_line = lines[first_line - 1].strip() - if prev_line.startswith("#") and COMMENT in prev_line: - first_line -= 1 - for i in range(first_line, item.end_lineno): - lines_to_remove.add(i) + lines_to_remove.update(_method_removal_range(item, lines)) else: for dec in item.decorator_list: dec_line = dec.lineno - 1 @@ -406,11 +523,18 @@ def remove_expected_failures( and lines[dec_line - 1].strip().startswith("#") and COMMENT in lines[dec_line - 1] ) + has_comment_after = ( + dec_line + 1 < len(lines) + and lines[dec_line + 1].strip().startswith("#") + and COMMENT not in lines[dec_line + 1] + ) if has_comment_on_line or has_comment_before: lines_to_remove.add(dec_line) if has_comment_before: lines_to_remove.add(dec_line - 1) + if has_comment_after and has_comment_on_line: + lines_to_remove.add(dec_line + 1) for line_idx in sorted(lines_to_remove, reverse=True): del lines[line_idx] @@ -481,12 +605,98 @@ def apply_test_changes( contents = remove_expected_failures(contents, unexpected_successes) if failing_tests: + failing_tests, error_messages = _consolidate_to_parent( + contents, failing_tests, error_messages + ) patches = build_patches(failing_tests, error_messages) contents = apply_patches(contents, patches) return contents +def strip_reasonless_expected_failures( + contents: str, +) -> tuple[str, set[tuple[str, str]]]: + """Strip @expectedFailure decorators that have no failure reason. + + Markers like ``@unittest.expectedFailure # TODO: RUSTPYTHON`` (without a + reason after the semicolon) are removed so the tests fail normally during + the next test run and error messages can be captured. + + Returns: + (modified_contents, stripped_tests) where stripped_tests is a set of + (class_name, method_name) tuples whose markers were removed. + """ + tree = ast.parse(contents) + lines = contents.splitlines() + stripped_tests: set[tuple[str, str]] = set() + lines_to_remove: set[int] = set() + + for node in ast.walk(tree): + if not isinstance(node, ast.ClassDef): + continue + for item in node.body: + if not isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)): + continue + for dec in item.decorator_list: + dec_line = dec.lineno - 1 + line_content = lines[dec_line] + + if "expectedFailure" not in line_content: + continue + + has_comment_on_line = COMMENT in line_content + has_comment_before = ( + dec_line > 0 + and lines[dec_line - 1].strip().startswith("#") + and COMMENT in lines[dec_line - 1] + ) + + if not has_comment_on_line and not has_comment_before: + continue # not our marker + + # Check if there's a reason (on either the decorator or before) + for check_line in ( + line_content, + lines[dec_line - 1] if has_comment_before else "", + ): + match = re.search(rf"{COMMENT}(.*)", check_line) + if match and match.group(1).strip(";:, "): + break # has a reason, keep it + else: + # No reason found — strip this decorator + stripped_tests.add((node.name, item.name)) + + if _is_super_call_only(item): + # Remove entire super-call override (the method + # exists only to apply the decorator; without it + # the override is pointless and blocks parent + # consolidation) + lines_to_remove.update(_method_removal_range(item, lines)) + else: + lines_to_remove.add(dec_line) + + if has_comment_before: + lines_to_remove.add(dec_line - 1) + + # Also remove a reason-comment on the line after (old format) + if ( + has_comment_on_line + and dec_line + 1 < len(lines) + and lines[dec_line + 1].strip().startswith("#") + and COMMENT not in lines[dec_line + 1] + ): + lines_to_remove.add(dec_line + 1) + + if not lines_to_remove: + return contents, stripped_tests + + for idx in sorted(lines_to_remove, reverse=True): + del lines[idx] + + return "\n".join(lines) + "\n" if lines else "", stripped_tests + + def extract_test_methods(contents: str) -> set[tuple[str, str]]: """ Extract all test method names from file contents. @@ -529,6 +739,13 @@ def auto_mark_file( if not test_path.exists(): raise FileNotFoundError(f"File not found: {test_path}") + # Strip reason-less markers so those tests fail normally and we capture + # their error messages during the test run. + contents = test_path.read_text(encoding="utf-8") + contents, stripped_tests = strip_reasonless_expected_failures(contents) + if stripped_tests: + test_path.write_text(contents, encoding="utf-8") + test_name = get_test_module_name(test_path) if verbose: print(f"Running test: {test_name}") @@ -559,6 +776,13 @@ def auto_mark_file( else: failing_tests = set() + # Re-mark stripped tests that still fail (to restore markers with reasons). + # Uses inheritance expansion: if a parent marker was stripped, child + # failures are included so _consolidate_to_parent can re-mark the parent. + failing_tests |= _expand_stripped_to_children( + contents, stripped_tests, all_failing_tests + ) + regressions = all_failing_tests - failing_tests if verbose: @@ -626,6 +850,19 @@ def auto_mark_directory( if not test_dir.is_dir(): raise ValueError(f"Not a directory: {test_dir}") + # Get all .py files in directory + test_files = sorted(test_dir.glob("**/*.py")) + + # Strip reason-less markers from ALL files before running tests so those + # tests fail normally and we capture their error messages. + stripped_per_file: dict[pathlib.Path, set[tuple[str, str]]] = {} + for test_file in test_files: + contents = test_file.read_text(encoding="utf-8") + contents, stripped = strip_reasonless_expected_failures(contents) + if stripped: + test_file.write_text(contents, encoding="utf-8") + stripped_per_file[test_file] = stripped + test_name = get_test_module_name(test_dir) if verbose: print(f"Running test: {test_name}") @@ -644,9 +881,6 @@ def auto_mark_directory( total_regressions = 0 all_regressions: list[tuple[str, str, str, str]] = [] - # Get all .py files in directory - test_files = sorted(test_dir.glob("**/*.py")) - for test_file in test_files: # Get module prefix for this file (e.g., "test_inspect.test_inspect") module_prefix = get_test_module_name(test_file) @@ -671,6 +905,15 @@ def auto_mark_directory( else: failing_tests = set() + # Re-mark stripped tests that still fail (restore markers with reasons). + # Uses inheritance expansion for parent→child mapping. + stripped = stripped_per_file.get(test_file, set()) + if stripped: + file_contents = test_file.read_text(encoding="utf-8") + failing_tests |= _expand_stripped_to_children( + file_contents, stripped, all_failing_tests + ) + regressions = all_failing_tests - failing_tests if failing_tests or unexpected_successes: diff --git a/scripts/update_lib/tests/test_auto_mark.py b/scripts/update_lib/tests/test_auto_mark.py index f4633ada6d3..15a80e49e44 100644 --- a/scripts/update_lib/tests/test_auto_mark.py +++ b/scripts/update_lib/tests/test_auto_mark.py @@ -1,11 +1,13 @@ """Tests for auto_mark.py - test result parsing and auto-marking.""" +import ast import subprocess import unittest from update_lib.cmd_auto_mark import ( Test, TestResult, + _expand_stripped_to_children, _is_super_call_only, apply_test_changes, collect_test_changes, @@ -13,65 +15,75 @@ parse_results, path_to_test_parts, remove_expected_failures, + strip_reasonless_expected_failures, ) from update_lib.patch_spec import COMMENT -class TestParseResults(unittest.TestCase): - """Tests for parse_results function.""" +def _make_result(stdout: str) -> subprocess.CompletedProcess: + return subprocess.CompletedProcess( + args=["test"], returncode=0, stdout=stdout, stderr="" + ) - def _make_result(self, stdout: str) -> subprocess.CompletedProcess: - """Create a mock CompletedProcess.""" - return subprocess.CompletedProcess( - args=["test"], - returncode=0, - stdout=stdout, - stderr="", - ) - def test_parse_failing_test(self): - """Test parsing a failing test.""" - stdout = """ -Run 1 tests sequentially -test_foo (test.test_example.TestClass.test_foo) ... FAIL ------------ +# -- fixtures shared across inheritance-aware tests -- + +BASE_TWO_CHILDREN = """import unittest + +class Base: + def test_foo(self): + pass + +class ChildA(Base, unittest.TestCase): + pass + +class ChildB(Base, unittest.TestCase): + pass """ - result = parse_results(self._make_result(stdout)) - self.assertEqual(len(result.tests), 1) - self.assertEqual(result.tests[0].name, "test_foo") - self.assertEqual(result.tests[0].path, "test.test_example.TestClass.test_foo") - self.assertEqual(result.tests[0].result, "fail") - - def test_parse_error_test(self): - """Test parsing an error test.""" - stdout = """ -Run 1 tests sequentially -test_bar (test.test_example.TestClass.test_bar) ... ERROR ------------ + +BASE_TWO_CHILDREN_ONE_OVERRIDE = """import unittest + +class Base: + def test_foo(self): + pass + +class ChildA(Base, unittest.TestCase): + pass + +class ChildB(Base, unittest.TestCase): + def test_foo(self): + # own implementation + pass """ - result = parse_results(self._make_result(stdout)) - self.assertEqual(len(result.tests), 1) - self.assertEqual(result.tests[0].result, "error") - def test_parse_ok_test_ignored(self): - """Test that passing tests are ignored.""" - stdout = """ -Run 1 tests sequentially -test_foo (test.test_example.TestClass.test_foo) ... ok + +class TestParseResults(unittest.TestCase): + """Tests for parse_results function.""" + + def test_parse_fail_and_error(self): + """FAIL and ERROR are collected; ok is ignored.""" + stdout = """\ +Run 3 tests sequentially +test_one (test.test_example.TestA.test_one) ... FAIL +test_two (test.test_example.TestA.test_two) ... ok +test_three (test.test_example.TestB.test_three) ... ERROR ----------- """ - result = parse_results(self._make_result(stdout)) - self.assertEqual(len(result.tests), 0) + result = parse_results(_make_result(stdout)) + self.assertEqual(len(result.tests), 2) + by_name = {t.name: t for t in result.tests} + self.assertEqual(by_name["test_one"].path, "test.test_example.TestA.test_one") + self.assertEqual(by_name["test_one"].result, "fail") + self.assertEqual(by_name["test_three"].result, "error") def test_parse_unexpected_success(self): - """Test parsing unexpected success.""" - stdout = """ + stdout = """\ Run 1 tests sequentially test_foo (test.test_example.TestClass.test_foo) ... unexpected success ----------- UNEXPECTED SUCCESS: test_foo (test.test_example.TestClass.test_foo) """ - result = parse_results(self._make_result(stdout)) + result = parse_results(_make_result(stdout)) self.assertEqual(len(result.unexpected_successes), 1) self.assertEqual(result.unexpected_successes[0].name, "test_foo") self.assertEqual( @@ -79,30 +91,15 @@ def test_parse_unexpected_success(self): ) def test_parse_tests_result(self): - """Test parsing tests result line.""" - stdout = """ -== Tests result: FAILURE == -""" - result = parse_results(self._make_result(stdout)) + result = parse_results(_make_result("== Tests result: FAILURE ==\n")) self.assertEqual(result.tests_result, "FAILURE") - def test_parse_multiple_tests(self): - """Test parsing multiple test results.""" - stdout = """ -Run 3 tests sequentially -test_one (test.test_example.TestA.test_one) ... FAIL -test_two (test.test_example.TestA.test_two) ... ok -test_three (test.test_example.TestB.test_three) ... ERROR ------------ -""" - result = parse_results(self._make_result(stdout)) - self.assertEqual(len(result.tests), 2) # Only FAIL and ERROR - - def test_parse_error_message(self): - """Test parsing error message from traceback.""" - stdout = """ -Run 1 tests sequentially + def test_parse_error_messages(self): + """Single and multiple error messages are parsed from tracebacks.""" + stdout = """\ +Run 2 tests sequentially test_foo (test.test_example.TestClass.test_foo) ... FAIL +test_bar (test.test_example.TestClass.test_bar) ... ERROR ----------- ====================================================================== FAIL: test_foo (test.test_example.TestClass.test_foo) @@ -112,19 +109,23 @@ def test_parse_error_message(self): self.assertEqual(1, 2) AssertionError: 1 != 2 +====================================================================== +ERROR: test_bar (test.test_example.TestClass.test_bar) +---------------------------------------------------------------------- +Traceback (most recent call last): + File "test.py", line 20, in test_bar + raise ValueError("oops") +ValueError: oops + ====================================================================== """ - result = parse_results(self._make_result(stdout)) - self.assertEqual(len(result.tests), 1) - self.assertEqual(result.tests[0].error_message, "AssertionError: 1 != 2") + result = parse_results(_make_result(stdout)) + by_name = {t.name: t for t in result.tests} + self.assertEqual(by_name["test_foo"].error_message, "AssertionError: 1 != 2") + self.assertEqual(by_name["test_bar"].error_message, "ValueError: oops") def test_parse_directory_test_multiple_submodules(self): - """Test parsing directory test output with multiple submodules. - - When running a directory test (e.g., test_asyncio), the output contains - multiple submodules separated by '------' lines. Failures in submodules - after the first one must still be detected. - """ + """Failures across submodule boundaries are all detected.""" stdout = """\ Run 3 tests sequentially 0:00:00 [ 1/3] test_asyncio.test_buffered_proto @@ -153,25 +154,15 @@ def test_parse_directory_test_multiple_submodules(self): == Tests result: FAILURE == """ - result = parse_results(self._make_result(stdout)) + result = parse_results(_make_result(stdout)) self.assertEqual(len(result.tests), 2) names = {t.name for t in result.tests} self.assertIn("test_create", names) self.assertIn("test_gather", names) - # Verify results - test_create = next(t for t in result.tests if t.name == "test_create") - test_gather = next(t for t in result.tests if t.name == "test_gather") - self.assertEqual(test_create.result, "fail") - self.assertEqual(test_gather.result, "error") self.assertEqual(result.tests_result, "FAILURE") def test_parse_multiline_test_with_docstring(self): - """Test parsing tests where docstring appears on a separate line. - - Some tests have docstrings that cause the output to span two lines: - test_name (path) - docstring ... ERROR - """ + """Two-line output (test_name + docstring ... RESULT) is handled.""" stdout = """\ Run 3 tests sequentially test_ok (test.test_example.TestClass.test_ok) ... ok @@ -179,7 +170,7 @@ def test_parse_multiline_test_with_docstring(self): Test that something works ... ERROR test_normal_fail (test.test_example.TestClass.test_normal_fail) ... FAIL """ - result = parse_results(self._make_result(stdout)) + result = parse_results(_make_result(stdout)) self.assertEqual(len(result.tests), 2) names = {t.name for t in result.tests} self.assertIn("test_with_doc", names) @@ -188,82 +179,51 @@ def test_parse_multiline_test_with_docstring(self): self.assertEqual(test_doc.path, "test.test_example.TestClass.test_with_doc") self.assertEqual(test_doc.result, "error") - def test_parse_multiple_error_messages(self): - """Test parsing multiple error messages.""" - stdout = """ -Run 2 tests sequentially -test_foo (test.test_example.TestClass.test_foo) ... FAIL -test_bar (test.test_example.TestClass.test_bar) ... ERROR ------------ -====================================================================== -FAIL: test_foo (test.test_example.TestClass.test_foo) ----------------------------------------------------------------------- -Traceback (most recent call last): - File "test.py", line 10, in test_foo - self.assertEqual(1, 2) -AssertionError: 1 != 2 - -====================================================================== -ERROR: test_bar (test.test_example.TestClass.test_bar) ----------------------------------------------------------------------- -Traceback (most recent call last): - File "test.py", line 20, in test_bar - raise ValueError("oops") -ValueError: oops - -====================================================================== -""" - result = parse_results(self._make_result(stdout)) - self.assertEqual(len(result.tests), 2) - # Find tests by name - test_foo = next(t for t in result.tests if t.name == "test_foo") - test_bar = next(t for t in result.tests if t.name == "test_bar") - self.assertEqual(test_foo.error_message, "AssertionError: 1 != 2") - self.assertEqual(test_bar.error_message, "ValueError: oops") - class TestPathToTestParts(unittest.TestCase): - """Tests for path_to_test_parts function.""" - def test_simple_path(self): - """Test extracting parts from simple path.""" - parts = path_to_test_parts("test.test_foo.TestClass.test_method") - self.assertEqual(parts, ["TestClass", "test_method"]) + self.assertEqual( + path_to_test_parts("test.test_foo.TestClass.test_method"), + ["TestClass", "test_method"], + ) def test_nested_path(self): - """Test extracting parts from nested path.""" - parts = path_to_test_parts("test.test_foo.test_bar.TestClass.test_method") - self.assertEqual(parts, ["TestClass", "test_method"]) + self.assertEqual( + path_to_test_parts("test.test_foo.test_bar.TestClass.test_method"), + ["TestClass", "test_method"], + ) class TestCollectTestChanges(unittest.TestCase): - """Tests for collect_test_changes function.""" - - def test_collect_failing_tests(self): - """Test collecting failing tests.""" + def test_collect_failures_and_error_messages(self): + """Failures and error messages are collected; empty messages are omitted.""" results = TestResult() results.tests = [ Test( name="test_foo", path="test.test_example.TestClass.test_foo", result="fail", + error_message="AssertionError: 1 != 2", ), Test( name="test_bar", path="test.test_example.TestClass.test_bar", result="error", + error_message="", ), ] - failing, successes, error_messages = collect_test_changes(results) - self.assertEqual(len(failing), 2) - self.assertIn(("TestClass", "test_foo"), failing) - self.assertIn(("TestClass", "test_bar"), failing) - self.assertEqual(len(successes), 0) + self.assertEqual( + failing, {("TestClass", "test_foo"), ("TestClass", "test_bar")} + ) + self.assertEqual(successes, set()) + self.assertEqual(len(error_messages), 1) + self.assertEqual( + error_messages[("TestClass", "test_foo")], "AssertionError: 1 != 2" + ) def test_collect_unexpected_successes(self): - """Test collecting unexpected successes.""" results = TestResult() results.unexpected_successes = [ Test( @@ -272,86 +232,35 @@ def test_collect_unexpected_successes(self): result="unexpected_success", ), ] + _, successes, _ = collect_test_changes(results) + self.assertEqual(successes, {("TestClass", "test_foo")}) - failing, successes, error_messages = collect_test_changes(results) - - self.assertEqual(len(failing), 0) - self.assertEqual(len(successes), 1) - self.assertIn(("TestClass", "test_foo"), successes) - - def test_collect_with_module_prefix(self): - """Test collecting with module prefix filter.""" + def test_module_prefix_filtering(self): + """Prefix filters with both short and 'test.' prefix formats.""" results = TestResult() results.tests = [ Test(name="test_foo", path="test_a.TestClass.test_foo", result="fail"), - Test(name="test_bar", path="test_b.TestClass.test_bar", result="fail"), - ] - - failing, _, _ = collect_test_changes(results, module_prefix="test_a.") - - self.assertEqual(len(failing), 1) - self.assertIn(("TestClass", "test_foo"), failing) - - def test_collect_error_messages(self): - """Test collecting error messages.""" - results = TestResult() - results.tests = [ - Test( - name="test_foo", - path="test.test_example.TestClass.test_foo", - result="fail", - error_message="AssertionError: 1 != 2", - ), Test( name="test_bar", - path="test.test_example.TestClass.test_bar", - result="error", - error_message="", - ), - ] - - failing, successes, error_messages = collect_test_changes(results) - - self.assertEqual(len(error_messages), 1) - self.assertEqual( - error_messages[("TestClass", "test_foo")], "AssertionError: 1 != 2" - ) - - def test_collect_with_test_prefix_in_path(self): - """Test collecting with 'test.' prefix in path (like real test output).""" - results = TestResult() - results.tests = [ - Test( - name="test_foo", - path="test.test_dataclasses.TestCase.test_foo", + path="test.test_dataclasses.TestCase.test_bar", result="fail", ), Test( - name="test_bar", - path="test.test_other.TestOther.test_bar", + name="test_baz", + path="test.test_other.TestOther.test_baz", result="fail", ), ] + failing_a, _, _ = collect_test_changes(results, module_prefix="test_a.") + self.assertEqual(failing_a, {("TestClass", "test_foo")}) - # Filter with prefix that matches real test module path format - failing, _, _ = collect_test_changes( + failing_dc, _, _ = collect_test_changes( results, module_prefix="test.test_dataclasses." ) - - self.assertEqual(len(failing), 1) - self.assertIn(("TestCase", "test_foo"), failing) + self.assertEqual(failing_dc, {("TestCase", "test_bar")}) def test_collect_init_module_matching(self): - """Test that __init__.py tests match without __init__ in path. - - When test results come from a package's __init__.py, the path is like: - 'test.test_dataclasses.TestCase.test_foo' (no __init__) - - But module_prefix from get_test_module_name would be: - 'test_dataclasses.__init__' - - So we need to strip '.__init__' and add 'test.' prefix. - """ + """__init__.py tests match after stripping .__init__ from the prefix.""" results = TestResult() results.tests = [ Test( @@ -360,39 +269,18 @@ def test_collect_init_module_matching(self): result="fail", ), ] - - # Simulate the corrected prefix (after stripping .__init__ and adding test.) module_prefix = "test_dataclasses.__init__" if module_prefix.endswith(".__init__"): module_prefix = module_prefix[:-9] module_prefix = "test." + module_prefix + "." failing, _, _ = collect_test_changes(results, module_prefix=module_prefix) - - self.assertEqual(len(failing), 1) - self.assertIn(("TestCase", "test_field_repr"), failing) + self.assertEqual(failing, {("TestCase", "test_field_repr")}) class TestExtractTestMethods(unittest.TestCase): - """Tests for extract_test_methods function.""" - - def test_extract_simple(self): - """Test extracting test methods from simple class.""" - code = """ -class TestFoo(unittest.TestCase): - def test_one(self): - pass - - def test_two(self): - pass -""" - methods = extract_test_methods(code) - self.assertEqual(len(methods), 2) - self.assertIn(("TestFoo", "test_one"), methods) - self.assertIn(("TestFoo", "test_two"), methods) - - def test_extract_multiple_classes(self): - """Test extracting from multiple classes.""" + def test_extract_methods(self): + """Extracts from single and multiple classes.""" code = """ class TestA(unittest.TestCase): def test_a(self): @@ -403,22 +291,14 @@ def test_b(self): pass """ methods = extract_test_methods(code) - self.assertEqual(len(methods), 2) - self.assertIn(("TestA", "test_a"), methods) - self.assertIn(("TestB", "test_b"), methods) + self.assertEqual(methods, {("TestA", "test_a"), ("TestB", "test_b")}) def test_extract_syntax_error_returns_empty(self): - """Test that syntax error returns empty set.""" - code = "this is not valid python {" - methods = extract_test_methods(code) - self.assertEqual(methods, set()) + self.assertEqual(extract_test_methods("this is not valid python {"), set()) class TestRemoveExpectedFailures(unittest.TestCase): - """Tests for remove_expected_failures function.""" - - def test_remove_simple(self): - """Test removing simple expectedFailure decorator.""" + def test_remove_comment_before(self): code = f"""import unittest class TestFoo(unittest.TestCase): @@ -431,8 +311,7 @@ def test_one(self): self.assertNotIn("@unittest.expectedFailure", result) self.assertIn("def test_one(self):", result) - def test_remove_with_inline_comment(self): - """Test removing expectedFailure with inline comment.""" + def test_remove_inline_comment(self): code = f"""import unittest class TestFoo(unittest.TestCase): @@ -444,7 +323,7 @@ def test_one(self): self.assertNotIn("@unittest.expectedFailure", result) def test_remove_super_call_method(self): - """Test removing method that just calls super().""" + """Super-call-only override is removed entirely (sync).""" code = f"""import unittest class TestFoo(unittest.TestCase): @@ -456,8 +335,43 @@ def test_one(self): result = remove_expected_failures(code, {("TestFoo", "test_one")}) self.assertNotIn("def test_one", result) + def test_remove_async_super_call_override(self): + """Super-call-only override is removed entirely (async).""" + code = f"""import unittest + +class BaseTest: + async def test_async_one(self): + pass + +class TestChild(BaseTest, unittest.TestCase): + # {COMMENT} + @unittest.expectedFailure + async def test_async_one(self): + return await super().test_async_one() +""" + result = remove_expected_failures(code, {("TestChild", "test_async_one")}) + self.assertNotIn("return await super().test_async_one()", result) + self.assertNotIn("@unittest.expectedFailure", result) + self.assertIn("class TestChild", result) + self.assertIn("async def test_async_one(self):", result) + + def test_remove_with_comment_after(self): + """Reason comment on the line after the decorator is also removed.""" + code = f"""import unittest + +class TestFoo(unittest.TestCase): + @unittest.expectedFailure # {COMMENT} + # RuntimeError: something went wrong + def test_one(self): + pass +""" + result = remove_expected_failures(code, {("TestFoo", "test_one")}) + self.assertNotIn("@unittest.expectedFailure", result) + self.assertNotIn("RuntimeError: something went wrong", result) + self.assertIn("def test_one(self):", result) + def test_no_removal_without_comment(self): - """Test that decorators without COMMENT are not removed.""" + """Decorators without our COMMENT marker are left untouched.""" code = """import unittest class TestFoo(unittest.TestCase): @@ -466,29 +380,187 @@ def test_one(self): pass """ result = remove_expected_failures(code, {("TestFoo", "test_one")}) - # Should still have the decorator self.assertIn("@unittest.expectedFailure", result) -class TestApplyTestChanges(unittest.TestCase): - """Tests for apply_test_changes function.""" +class TestStripReasonlessExpectedFailures(unittest.TestCase): + def test_strip_reason_formats(self): + """Strips both inline-comment and comment-before formats when no reason.""" + for label, code in [ + ( + "inline", + f"""import unittest - def test_apply_failing_tests(self): - """Test applying expectedFailure to failing tests.""" +class TestFoo(unittest.TestCase): + @unittest.expectedFailure # {COMMENT} + def test_one(self): + pass +""", + ), + ( + "comment-before", + f"""import unittest + +class TestFoo(unittest.TestCase): + # {COMMENT} + @unittest.expectedFailure + def test_one(self): + pass +""", + ), + ]: + with self.subTest(label): + result, stripped = strip_reasonless_expected_failures(code) + self.assertNotIn("@unittest.expectedFailure", result) + self.assertIn("def test_one(self):", result) + self.assertEqual(stripped, {("TestFoo", "test_one")}) + + def test_keep_with_reason(self): + code = f"""import unittest + +class TestFoo(unittest.TestCase): + @unittest.expectedFailure # {COMMENT}; AssertionError: 1 != 2 + def test_one(self): + pass +""" + result, stripped = strip_reasonless_expected_failures(code) + self.assertIn("@unittest.expectedFailure", result) + self.assertEqual(stripped, set()) + + def test_strip_with_comment_after(self): + """Old-format reason comment on the next line is also removed.""" + code = f"""import unittest + +class TestFoo(unittest.TestCase): + @unittest.expectedFailure # {COMMENT} + # RuntimeError: something went wrong + def test_one(self): + pass +""" + result, stripped = strip_reasonless_expected_failures(code) + self.assertNotIn("RuntimeError", result) + self.assertIn("def test_one(self):", result) + self.assertEqual(stripped, {("TestFoo", "test_one")}) + + def test_strip_super_call_override(self): + """Super-call overrides are removed entirely (both comment formats).""" + for label, code in [ + ( + "comment-before", + f"""import unittest + +class _BaseTests: + def test_foo(self): + pass + +class TestChild(_BaseTests, unittest.TestCase): + # {COMMENT} + @unittest.expectedFailure + def test_foo(self): + return super().test_foo() +""", + ), + ( + "inline", + f"""import unittest + +class _BaseTests: + def test_foo(self): + pass + +class TestChild(_BaseTests, unittest.TestCase): + @unittest.expectedFailure # {COMMENT} + def test_foo(self): + return super().test_foo() +""", + ), + ]: + with self.subTest(label): + result, stripped = strip_reasonless_expected_failures(code) + self.assertNotIn("return super().test_foo()", result) + self.assertNotIn("@unittest.expectedFailure", result) + self.assertEqual(stripped, {("TestChild", "test_foo")}) + self.assertIn("class _BaseTests:", result) + + def test_no_strip_without_comment(self): + """Markers without our COMMENT are NOT stripped.""" + code = """import unittest + +class TestFoo(unittest.TestCase): + @unittest.expectedFailure + def test_one(self): + pass +""" + result, stripped = strip_reasonless_expected_failures(code) + self.assertIn("@unittest.expectedFailure", result) + self.assertEqual(stripped, set()) + + def test_mixed_with_and_without_reason(self): + code = f"""import unittest + +class TestFoo(unittest.TestCase): + @unittest.expectedFailure # {COMMENT} + def test_no_reason(self): + pass + + @unittest.expectedFailure # {COMMENT}; has a reason + def test_has_reason(self): + pass +""" + result, stripped = strip_reasonless_expected_failures(code) + self.assertEqual(stripped, {("TestFoo", "test_no_reason")}) + self.assertIn("has a reason", result) + self.assertEqual(result.count("@unittest.expectedFailure"), 1) + + +class TestExpandStrippedToChildren(unittest.TestCase): + def test_parent_to_children(self): + """Parent stripped → all/partial failing children returned.""" + stripped = {("Base", "test_foo")} + all_children = {("ChildA", "test_foo"), ("ChildB", "test_foo")} + + # All children fail + result = _expand_stripped_to_children(BASE_TWO_CHILDREN, stripped, all_children) + self.assertEqual(result, all_children) + + # Only one child fails + partial = {("ChildA", "test_foo")} + result = _expand_stripped_to_children(BASE_TWO_CHILDREN, stripped, partial) + self.assertEqual(result, partial) + + def test_direct_match(self): code = """import unittest class TestFoo(unittest.TestCase): def test_one(self): pass """ - failing = {("TestFoo", "test_one")} - result = apply_test_changes(code, failing, set()) + s = {("TestFoo", "test_one")} + self.assertEqual(_expand_stripped_to_children(code, s, s), s) + + def test_child_with_own_override_excluded(self): + stripped = {("Base", "test_foo")} + all_failing = {("ChildA", "test_foo"), ("ChildB", "test_foo")} + result = _expand_stripped_to_children( + BASE_TWO_CHILDREN_ONE_OVERRIDE, stripped, all_failing + ) + # ChildA inherits → included; ChildB has own method → excluded + self.assertEqual(result, {("ChildA", "test_foo")}) + + +class TestApplyTestChanges(unittest.TestCase): + def test_apply_failing_tests(self): + code = """import unittest +class TestFoo(unittest.TestCase): + def test_one(self): + pass +""" + result = apply_test_changes(code, {("TestFoo", "test_one")}, set()) self.assertIn("@unittest.expectedFailure", result) self.assertIn(COMMENT, result) def test_apply_removes_unexpected_success(self): - """Test removing expectedFailure from unexpected success.""" code = f"""import unittest class TestFoo(unittest.TestCase): @@ -497,14 +569,11 @@ class TestFoo(unittest.TestCase): def test_one(self): pass """ - successes = {("TestFoo", "test_one")} - result = apply_test_changes(code, set(), successes) - + result = apply_test_changes(code, set(), {("TestFoo", "test_one")}) self.assertNotIn("@unittest.expectedFailure", result) self.assertIn("def test_one(self):", result) def test_apply_both_changes(self): - """Test applying both failing tests and removing unexpected successes.""" code = f"""import unittest class TestFoo(unittest.TestCase): @@ -516,299 +585,166 @@ def test_one(self): def test_two(self): pass """ - failing = {("TestFoo", "test_one")} - successes = {("TestFoo", "test_two")} - result = apply_test_changes(code, failing, successes) - - # test_one should now have expectedFailure - self.assertIn("def test_one(self):", result) - # Only one expectedFailure decorator should remain (on test_one) + result = apply_test_changes( + code, {("TestFoo", "test_one")}, {("TestFoo", "test_two")} + ) self.assertEqual(result.count("@unittest.expectedFailure"), 1) def test_apply_with_error_message(self): - """Test applying expectedFailure with error message.""" code = """import unittest class TestFoo(unittest.TestCase): def test_one(self): pass """ - failing = {("TestFoo", "test_one")} - error_messages = {("TestFoo", "test_one"): "AssertionError: 1 != 2"} - result = apply_test_changes(code, failing, set(), error_messages) - - self.assertIn("@unittest.expectedFailure", result) + result = apply_test_changes( + code, + {("TestFoo", "test_one")}, + set(), + {("TestFoo", "test_one"): "AssertionError: 1 != 2"}, + ) self.assertIn("AssertionError: 1 != 2", result) self.assertIn(COMMENT, result) -class TestSmartAutoMarkFiltering(unittest.TestCase): - """Tests for smart auto-mark filtering logic (regression exclusion). - - The smart auto-mark feature: - - Marks NEW test failures (tests that didn't exist before) - - Does NOT mark regressions (existing tests that now fail) - """ - - def _filter_failures( - self, - all_failing_tests: set[tuple[str, str]], - original_methods: set[tuple[str, str]], - current_methods: set[tuple[str, str]], - ) -> tuple[set[tuple[str, str]], set[tuple[str, str]]]: - """Simulate the filtering logic from auto_mark_file(). - - Returns: - (failing_tests_to_mark, regressions) - """ - new_methods = current_methods - original_methods - failing_tests = {t for t in all_failing_tests if t in new_methods} - regressions = all_failing_tests - failing_tests - return failing_tests, regressions - - def test_new_tests_get_marked(self): - """Test that new failing tests are marked.""" - original_methods = {("TestFoo", "test_existing")} - current_methods = {("TestFoo", "test_existing"), ("TestFoo", "test_new")} - all_failing = {("TestFoo", "test_new")} - - to_mark, regressions = self._filter_failures( - all_failing, original_methods, current_methods - ) +class TestConsolidateToParent(unittest.TestCase): + def test_all_children_fail_marks_parent_with_message(self): + """All subclasses fail → marks parent; error message is transferred.""" + failing = {("ChildA", "test_foo"), ("ChildB", "test_foo")} + error_messages = {("ChildA", "test_foo"): "RuntimeError: boom"} + result = apply_test_changes(BASE_TWO_CHILDREN, failing, set(), error_messages) - self.assertEqual(to_mark, {("TestFoo", "test_new")}) - self.assertEqual(regressions, set()) + self.assertEqual(result.count("@unittest.expectedFailure"), 1) + self.assertNotIn("return super()", result) + self.assertIn("RuntimeError: boom", result) + + def test_partial_children_fail_marks_children(self): + result = apply_test_changes(BASE_TWO_CHILDREN, {("ChildA", "test_foo")}, set()) + self.assertIn("return super().test_foo()", result) + self.assertEqual(result.count("@unittest.expectedFailure"), 1) - def test_regressions_not_marked(self): - """Test that existing failing tests (regressions) are NOT marked.""" - original_methods = {("TestFoo", "test_existing")} - current_methods = {("TestFoo", "test_existing")} - all_failing = {("TestFoo", "test_existing")} + def test_child_with_own_override_not_consolidated(self): + failing = {("ChildA", "test_foo"), ("ChildB", "test_foo")} + result = apply_test_changes(BASE_TWO_CHILDREN_ONE_OVERRIDE, failing, set()) + self.assertEqual(result.count("@unittest.expectedFailure"), 2) - to_mark, regressions = self._filter_failures( - all_failing, original_methods, current_methods - ) + def test_strip_then_consolidate_restores_parent_marker(self): + """End-to-end: strip parent marker → child failures → re-mark on parent.""" + code = f"""import unittest - self.assertEqual(to_mark, set()) - self.assertEqual(regressions, {("TestFoo", "test_existing")}) - - def test_mixed_new_and_regression(self): - """Test with both new failures and regressions.""" - original_methods = {("TestFoo", "test_old1"), ("TestFoo", "test_old2")} - current_methods = { - ("TestFoo", "test_old1"), - ("TestFoo", "test_old2"), - ("TestFoo", "test_new1"), - ("TestFoo", "test_new2"), - } - # test_old1 regressed, test_new1 is a new failure - all_failing = {("TestFoo", "test_old1"), ("TestFoo", "test_new1")} +class _BaseTests: + @unittest.expectedFailure # {COMMENT} + def test_foo(self): + pass - to_mark, regressions = self._filter_failures( - all_failing, original_methods, current_methods - ) +class ChildA(_BaseTests, unittest.TestCase): + pass - self.assertEqual(to_mark, {("TestFoo", "test_new1")}) - self.assertEqual(regressions, {("TestFoo", "test_old1")}) +class ChildB(_BaseTests, unittest.TestCase): + pass +""" + stripped_code, stripped_tests = strip_reasonless_expected_failures(code) + self.assertEqual(stripped_tests, {("_BaseTests", "test_foo")}) - def test_multiple_classes(self): - """Test filtering across multiple classes.""" - original_methods = {("TestA", "test_a"), ("TestB", "test_b")} - current_methods = { - ("TestA", "test_a"), - ("TestA", "test_new_a"), - ("TestB", "test_b"), - ("TestC", "test_c"), # entirely new class - } - all_failing = { - ("TestA", "test_a"), # regression - ("TestA", "test_new_a"), # new - ("TestC", "test_c"), # new (new class) - } + all_failing = {("ChildA", "test_foo"), ("ChildB", "test_foo")} + error_messages = {("ChildA", "test_foo"): "RuntimeError: boom"} - to_mark, regressions = self._filter_failures( - all_failing, original_methods, current_methods + to_remark = _expand_stripped_to_children( + stripped_code, stripped_tests, all_failing ) + self.assertEqual(to_remark, all_failing) - self.assertEqual(to_mark, {("TestA", "test_new_a"), ("TestC", "test_c")}) - self.assertEqual(regressions, {("TestA", "test_a")}) + result = apply_test_changes(stripped_code, to_remark, set(), error_messages) + self.assertIn("RuntimeError: boom", result) + self.assertEqual(result.count("@unittest.expectedFailure"), 1) + self.assertNotIn("return super()", result) - def test_all_new_tests(self): - """Test when all failing tests are new (no regressions).""" - original_methods = set() # file was new - current_methods = {("TestFoo", "test_one"), ("TestFoo", "test_two")} - all_failing = {("TestFoo", "test_one"), ("TestFoo", "test_two")} - to_mark, regressions = self._filter_failures( - all_failing, original_methods, current_methods - ) +class TestSmartAutoMarkFiltering(unittest.TestCase): + """Tests for smart auto-mark filtering (new tests vs regressions).""" + + @staticmethod + def _filter(all_failing, original, current): + new = current - original + to_mark = {t for t in all_failing if t in new} + return to_mark, all_failing - to_mark + + def test_new_vs_regression(self): + """New failures are marked; existing (regression) failures are not.""" + original = {("TestFoo", "test_old1"), ("TestFoo", "test_old2")} + current = original | {("TestFoo", "test_new1"), ("TestFoo", "test_new2")} + all_failing = {("TestFoo", "test_old1"), ("TestFoo", "test_new1")} + + to_mark, regressions = self._filter(all_failing, original, current) + self.assertEqual(to_mark, {("TestFoo", "test_new1")}) + self.assertEqual(regressions, {("TestFoo", "test_old1")}) + # Edge: all new → all marked + to_mark, regressions = self._filter(all_failing, set(), current) self.assertEqual(to_mark, all_failing) self.assertEqual(regressions, set()) - def test_all_regressions(self): - """Test when all failing tests are regressions (no new tests).""" - original_methods = {("TestFoo", "test_one"), ("TestFoo", "test_two")} - current_methods = original_methods.copy() - all_failing = {("TestFoo", "test_one")} - - to_mark, regressions = self._filter_failures( - all_failing, original_methods, current_methods - ) - + # Edge: all old → nothing marked + to_mark, regressions = self._filter(all_failing, current, current) self.assertEqual(to_mark, set()) - self.assertEqual(regressions, {("TestFoo", "test_one")}) - + self.assertEqual(regressions, all_failing) -class TestIsSuperCallOnly(unittest.TestCase): - """Tests for _is_super_call_only function.""" + def test_filters_across_classes(self): + original = {("TestA", "test_a"), ("TestB", "test_b")} + current = original | {("TestA", "test_new_a"), ("TestC", "test_c")} + all_failing = { + ("TestA", "test_a"), # regression + ("TestA", "test_new_a"), # new + ("TestC", "test_c"), # new (new class) + } + to_mark, regressions = self._filter(all_failing, original, current) + self.assertEqual(to_mark, {("TestA", "test_new_a"), ("TestC", "test_c")}) + self.assertEqual(regressions, {("TestA", "test_a")}) - def _parse_method(self, code: str): - """Parse code and return the first method.""" - import ast +class TestIsSuperCallOnly(unittest.TestCase): + @staticmethod + def _parse_method(code): tree = ast.parse(code) for node in ast.walk(tree): if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): return node return None - def test_matching_super_call(self): - """Test method that calls super().same_name().""" - code = """ -class Foo: - def test_one(self): - return super().test_one() -""" - method = self._parse_method(code) - self.assertTrue(_is_super_call_only(method)) - - def test_mismatched_super_call(self): - """Test method that calls super().different_name().""" - code = """ -class Foo: - def test_one(self): - return super().test_two() -""" - method = self._parse_method(code) - self.assertFalse(_is_super_call_only(method)) - - def test_not_super_call(self): - """Test method with regular body.""" - code = """ -class Foo: - def test_one(self): - pass -""" - method = self._parse_method(code) - self.assertFalse(_is_super_call_only(method)) - - def test_multiple_statements(self): - """Test method with multiple statements.""" - code = """ + def test_sync(self): + cases = [ + ("return super().test_one()", True), + ("return super().test_two()", False), # mismatched name + ("pass", False), # regular body + ("x = 1\n return super().test_one()", False), # multiple stmts + ] + for body, expected in cases: + with self.subTest(body=body): + code = f""" class Foo: def test_one(self): - x = 1 - return super().test_one() -""" - method = self._parse_method(code) - self.assertFalse(_is_super_call_only(method)) - - def test_async_await_super_call(self): - """Test async method that awaits super().same_name().""" - code = """ -class Foo: - async def test_one(self): - return await super().test_one() -""" - method = self._parse_method(code) - self.assertTrue(_is_super_call_only(method)) - - def test_async_await_mismatched_super_call(self): - """Test async method that awaits super().different_name().""" - code = """ -class Foo: - async def test_one(self): - return await super().test_two() + {body} """ - method = self._parse_method(code) - self.assertFalse(_is_super_call_only(method)) - - def test_async_without_await(self): - """Test async method that calls super() without await (sync super call in async method).""" - code = """ + self.assertEqual( + _is_super_call_only(self._parse_method(code)), expected + ) + + def test_async(self): + cases = [ + ("return await super().test_one()", True), + ("return await super().test_two()", False), + ("return super().test_one()", True), # sync call in async method + ] + for body, expected in cases: + with self.subTest(body=body): + code = f""" class Foo: async def test_one(self): - return super().test_one() -""" - method = self._parse_method(code) - self.assertTrue(_is_super_call_only(method)) - - -class TestAsyncInheritedOverride(unittest.TestCase): - """Tests for async inherited method override generation.""" - - def test_inherited_async_method_generates_async_override(self): - """Test that inherited async methods get async def + await override.""" - code = """import unittest - -class BaseTest: - async def test_async_one(self): - pass - -class TestChild(BaseTest, unittest.TestCase): - pass -""" - failing = {("TestChild", "test_async_one")} - result = apply_test_changes(code, failing, set()) - - self.assertIn("async def test_async_one(self):", result) - self.assertIn("return await super().test_async_one()", result) - self.assertIn("@unittest.expectedFailure", result) - - def test_inherited_sync_method_generates_sync_override(self): - """Test that inherited sync methods get sync def override.""" - code = """import unittest - -class BaseTest: - def test_sync_one(self): - pass - -class TestChild(BaseTest, unittest.TestCase): - pass -""" - failing = {("TestChild", "test_sync_one")} - result = apply_test_changes(code, failing, set()) - - self.assertIn("def test_sync_one(self):", result) - self.assertIn("return super().test_sync_one()", result) - self.assertNotIn("async def test_sync_one", result) - self.assertNotIn("await", result) - - def test_remove_async_super_call_override(self): - """Test removing async super call override on unexpected success.""" - code = f"""import unittest - -class BaseTest: - async def test_async_one(self): - pass - -class TestChild(BaseTest, unittest.TestCase): - # {COMMENT} - @unittest.expectedFailure - async def test_async_one(self): - return await super().test_async_one() + {body} """ - successes = {("TestChild", "test_async_one")} - result = apply_test_changes(code, set(), successes) - - # The override in TestChild should be removed; base class method remains - self.assertNotIn("return await super().test_async_one()", result) - self.assertNotIn("@unittest.expectedFailure", result) - self.assertIn("class TestChild", result) - # Base class method should still be present - self.assertIn("class BaseTest", result) - self.assertIn("async def test_async_one(self):", result) + self.assertEqual( + _is_super_call_only(self._parse_method(code)), expected + ) if __name__ == "__main__":