diff --git a/.github/scripts/test_update_disabled_issues.py b/.github/scripts/test_update_disabled_issues.py index d30eb7f7af..32451855e9 100644 --- a/.github/scripts/test_update_disabled_issues.py +++ b/.github/scripts/test_update_disabled_issues.py @@ -2,8 +2,8 @@ from update_disabled_issues import ( condense_disable_jobs, - condense_disable_tests, filter_disable_issues, + get_disabled_tests, get_disable_issues, OWNER, REPO, @@ -98,13 +98,13 @@ def test_filter_disable_issues(self, mock_get_disable_issues): [item["number"] for item in disabled_jobs], [32132, 42345, 94861] ) - def test_condense_disable_tests(self, mock_get_disable_issues): + def test_get_disable_tests(self, mock_get_disable_issues): mock_get_disable_issues.return_value = MOCK_DATA disabled_issues = get_disable_issues("dummy token") disabled_tests, _ = filter_disable_issues(disabled_issues) - results = condense_disable_tests(disabled_tests) + results = get_disabled_tests(disabled_tests) self.assertDictEqual( { @@ -125,6 +125,83 @@ def test_condense_disable_tests(self, mock_get_disable_issues): results, ) + def test_get_disable_tests_aggregate_issue(self, mock_get_disable_issues): + self.maxDiff = None + mock_data = [ + { + "url": "https://github.com/pytorch/pytorch/issues/32644", + "number": 32644, + "title": "DISABLED MULTIPLE dummy test", + "body": "disable the following tests:\n```\ntest_quantized_nn (test_quantization.PostTrainingDynamicQuantTest): mac, win\ntest_zero_redundancy_optimizer (__main__.TestZeroRedundancyOptimizerDistributed)\n```", + } + ] + disabled_tests = get_disabled_tests(mock_data) + self.assertDictEqual( + { + "test_quantized_nn (test_quantization.PostTrainingDynamicQuantTest)": ( + str(mock_data[0]["number"]), + mock_data[0]["url"], + ["mac", "win"], + ), + "test_zero_redundancy_optimizer (__main__.TestZeroRedundancyOptimizerDistributed)": ( + str(mock_data[0]["number"]), + mock_data[0]["url"], + [], + ), + }, + disabled_tests, + ) + + def test_get_disable_tests_merge_issues(self, mock_get_disable_issues): + self.maxDiff = None + mock_data = [ + { + "url": "https://github.com/pytorch/pytorch/issues/32644", + "number": 32644, + "title": "DISABLED MULTIPLE dummy test", + "body": "disable the following tests:\n```\ntest_2 (abc.ABC): mac, win\ntest_3 (DEF)\n```", + }, + { + "url": "https://github.com/pytorch/pytorch/issues/32645", + "number": 32645, + "title": "DISABLED MULTIPLE dummy test", + "body": "disable the following tests:\n```\ntest_2 (abc.ABC): mac, win, linux\ntest_3 (DEF): mac\n```", + }, + { + "url": "https://github.com/pytorch/pytorch/issues/32646", + "number": 32646, + "title": "DISABLED test_1 (__main__.Test1)", + "body": "platforms: linux", + }, + { + "url": "https://github.com/pytorch/pytorch/issues/32647", + "number": 32647, + "title": "DISABLED test_2 (abc.ABC)", + "body": "platforms: dynamo", + }, + ] + disabled_tests = get_disabled_tests(mock_data) + self.assertDictEqual( + { + "test_2 (abc.ABC)": ( + str(mock_data[3]["number"]), + mock_data[3]["url"], + ["dynamo", "linux", "mac", "win"], + ), + "test_3 (DEF)": ( + str(mock_data[1]["number"]), + mock_data[1]["url"], + [], + ), + "test_1 (__main__.Test1)": ( + str(mock_data[2]["number"]), + mock_data[2]["url"], + ["linux"], + ), + }, + disabled_tests, + ) + def test_condense_disable_jobs(self, mock_get_disable_issues): mock_get_disable_issues.return_value = MOCK_DATA diff --git a/.github/scripts/update_disabled_issues.py b/.github/scripts/update_disabled_issues.py index 1ef1b31d3f..238413df19 100755 --- a/.github/scripts/update_disabled_issues.py +++ b/.github/scripts/update_disabled_issues.py @@ -156,10 +156,15 @@ def update_disabled_tests( if key not in disabled_tests: disabled_tests[key] = (number, url, platforms_to_skip) else: + original_platforms = disabled_tests[key][2] + if len(original_platforms) == 0 or len(platforms_to_skip) == 0: + platforms = [] + else: + platforms = sorted(set(original_platforms + platforms_to_skip)) disabled_tests[key] = ( number, url, - list(set(disabled_tests[key][2] + platforms_to_skip)), + platforms, ) test_name_regex = re.compile(r"(test_[a-zA-Z0-9-_\.]+)\s+\(([a-zA-Z0-9-_\.]+)\)") @@ -196,8 +201,7 @@ def parse_test_name(s: str) -> Optional[str]: if "```" in line: break split_by_colon = line.split(":") - if len(split_by_colon) != 2: - continue + test_name = parse_test_name(split_by_colon[0].strip()) if test_name is None: continue @@ -205,7 +209,12 @@ def parse_test_name(s: str) -> Optional[str]: test_name, number, url, - get_platforms_to_skip(split_by_colon[1].strip(), ""), + get_platforms_to_skip( + split_by_colon[1].strip() + if len(split_by_colon) > 1 + else "", + "", + ), ) else: print(f"Unknown disable issue type: {title}")