Skip to content

Commit

Permalink
tc
Browse files Browse the repository at this point in the history
  • Loading branch information
clee2000 committed Dec 20, 2024
1 parent a7cc124 commit 1e91120
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 7 deletions.
83 changes: 80 additions & 3 deletions .github/scripts/test_update_disabled_issues.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

from update_disabled_issues import (
condense_disable_jobs,
condense_disable_tests,
filter_disable_issues,
get_disabled_tests,
get_disable_issues,
OWNER,
REPO,
Expand Down Expand Up @@ -98,13 +98,13 @@ def test_filter_disable_issues(self, mock_get_disable_issues):
[item["number"] for item in disabled_jobs], [32132, 42345, 94861]
)

def test_condense_disable_tests(self, mock_get_disable_issues):
def test_get_disable_tests(self, mock_get_disable_issues):
mock_get_disable_issues.return_value = MOCK_DATA

disabled_issues = get_disable_issues("dummy token")

disabled_tests, _ = filter_disable_issues(disabled_issues)
results = condense_disable_tests(disabled_tests)
results = get_disabled_tests(disabled_tests)

self.assertDictEqual(
{
Expand All @@ -125,6 +125,83 @@ def test_condense_disable_tests(self, mock_get_disable_issues):
results,
)

def test_get_disable_tests_aggregate_issue(self, mock_get_disable_issues):
self.maxDiff = None
mock_data = [
{
"url": "https://github.com/pytorch/pytorch/issues/32644",
"number": 32644,
"title": "DISABLED MULTIPLE dummy test",
"body": "disable the following tests:\n```\ntest_quantized_nn (test_quantization.PostTrainingDynamicQuantTest): mac, win\ntest_zero_redundancy_optimizer (__main__.TestZeroRedundancyOptimizerDistributed)\n```",
}
]
disabled_tests = get_disabled_tests(mock_data)
self.assertDictEqual(
{
"test_quantized_nn (test_quantization.PostTrainingDynamicQuantTest)": (
str(mock_data[0]["number"]),
mock_data[0]["url"],
["mac", "win"],
),
"test_zero_redundancy_optimizer (__main__.TestZeroRedundancyOptimizerDistributed)": (
str(mock_data[0]["number"]),
mock_data[0]["url"],
[],
),
},
disabled_tests,
)

def test_get_disable_tests_merge_issues(self, mock_get_disable_issues):
self.maxDiff = None
mock_data = [
{
"url": "https://github.com/pytorch/pytorch/issues/32644",
"number": 32644,
"title": "DISABLED MULTIPLE dummy test",
"body": "disable the following tests:\n```\ntest_2 (abc.ABC): mac, win\ntest_3 (DEF)\n```",
},
{
"url": "https://github.com/pytorch/pytorch/issues/32645",
"number": 32645,
"title": "DISABLED MULTIPLE dummy test",
"body": "disable the following tests:\n```\ntest_2 (abc.ABC): mac, win, linux\ntest_3 (DEF): mac\n```",
},
{
"url": "https://github.com/pytorch/pytorch/issues/32646",
"number": 32646,
"title": "DISABLED test_1 (__main__.Test1)",
"body": "platforms: linux",
},
{
"url": "https://github.com/pytorch/pytorch/issues/32647",
"number": 32647,
"title": "DISABLED test_2 (abc.ABC)",
"body": "platforms: dynamo",
},
]
disabled_tests = get_disabled_tests(mock_data)
self.assertDictEqual(
{
"test_2 (abc.ABC)": (
str(mock_data[3]["number"]),
mock_data[3]["url"],
["dynamo", "linux", "mac", "win"],
),
"test_3 (DEF)": (
str(mock_data[1]["number"]),
mock_data[1]["url"],
[],
),
"test_1 (__main__.Test1)": (
str(mock_data[2]["number"]),
mock_data[2]["url"],
["linux"],
),
},
disabled_tests,
)

def test_condense_disable_jobs(self, mock_get_disable_issues):
mock_get_disable_issues.return_value = MOCK_DATA

Expand Down
17 changes: 13 additions & 4 deletions .github/scripts/update_disabled_issues.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,15 @@ def update_disabled_tests(
if key not in disabled_tests:
disabled_tests[key] = (number, url, platforms_to_skip)
else:
original_platforms = disabled_tests[key][2]
if len(original_platforms) == 0 or len(platforms_to_skip) == 0:
platforms = []
else:
platforms = sorted(set(original_platforms + platforms_to_skip))
disabled_tests[key] = (
number,
url,
list(set(disabled_tests[key][2] + platforms_to_skip)),
platforms,
)

test_name_regex = re.compile(r"(test_[a-zA-Z0-9-_\.]+)\s+\(([a-zA-Z0-9-_\.]+)\)")
Expand Down Expand Up @@ -196,16 +201,20 @@ def parse_test_name(s: str) -> Optional[str]:
if "```" in line:
break
split_by_colon = line.split(":")
if len(split_by_colon) != 2:
continue

test_name = parse_test_name(split_by_colon[0].strip())
if test_name is None:
continue
update_disabled_tests(
test_name,
number,
url,
get_platforms_to_skip(split_by_colon[1].strip(), ""),
get_platforms_to_skip(
split_by_colon[1].strip()
if len(split_by_colon) > 1
else "",
"",
),
)
else:
print(f"Unknown disable issue type: {title}")
Expand Down

0 comments on commit 1e91120

Please sign in to comment.