Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make accept_plugin viable, add a way to suppress assertions #57

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 36 additions & 18 deletions pytest_accept/__init__.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,36 @@
from .doctest_plugin import (
pytest_addoption,
pytest_collect_file,
pytest_configure,
pytest_runtest_makereport,
pytest_sessionfinish,
)

# Note that pluginsn need to be listed here in order for pytest to pick them up when
# this package is installed.

__all__ = [
pytest_runtest_makereport,
pytest_sessionfinish,
pytest_addoption,
pytest_configure,
pytest_collect_file,
]
import pytest


def pytest_sessionstart(session):
from .assert_plugin import pytest_sessionstart

pytest_sessionstart(session)


def pytest_sessionfinish(session, exitstatus):
from .assert_plugin import pytest_sessionfinish

pytest_sessionfinish(session, exitstatus)


def pytest_assertrepr_compare(config, op, left, right):
from .assert_plugin import pytest_assertrepr_compare

pytest_assertrepr_compare(config, op, left, right)


def pytest_addoption(parser):
"""Add pytest-accept options to pytest"""
group = parser.getgroup("accept", "accept test plugin")
group.addoption(
"--accept",
dest="ACCEPT",
default="",
help="Write a .new file with new file contents ('new'), or overwrite the original test file ('overwrite')",
)
group.addoption(
"--accept-continue",
action="store_true",
default=False,
help="Continue after the first test failure",
)
109 changes: 89 additions & 20 deletions pytest_accept/assert_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,6 @@ def test_x():
assert 2 == 3
```

...whether or not it overwrites the original file or creates a new file with a
`.new` suffix is controlled by the `OVERWRITE` constant in `conftest.py`.

## Current shortcomings

### Big ones
Expand Down Expand Up @@ -73,76 +70,148 @@ def test_x():
we can make it one.
"""


import ast
import copy
import logging
import sys
from collections import defaultdict
from datetime import datetime
from typing import Dict, List, Tuple

import astor
import pytest
import yaml
from _pytest._code.code import ExceptionInfo

logger = logging.getLogger(__name__)

# Dict of {path: list of (location, new code)}
asts_modified: Dict[str, List[Tuple[slice, str]]] = defaultdict(list)

OVERWRITE = False
INTERCEPT_ASSERTIONS = False

_ASSERTION_HANDLER = ast.parse(
"""
__import__("pytest_accept").assert_plugin.__handle_failed_assertion()
"""
).body


def _patch_assertion_rewriter():
# I'm so sorry.

from _pytest.assertion.rewrite import AssertionRewriter

old_visit_assert = AssertionRewriter.visit_Assert

def new_visit_assert(self, assert_):
rv = old_visit_assert(self, assert_)

try_except = ast.Try(
body=rv,
handlers=[
ast.ExceptHandler(
expr=AssertionError,
identifier="__pytest_accept_e",
body=_ASSERTION_HANDLER,
)
],
orelse=[],
finalbody=[],
)

ast.copy_location(try_except, assert_)
for node in ast.iter_child_nodes(try_except):
ast.copy_location(node, assert_)

return [try_except]

AssertionRewriter.visit_Assert = new_visit_assert

@pytest.hookimpl(hookwrapper=True, tryfirst=True)
def pytest_runtest_makereport(item, call):
outcome = yield

if not call.excinfo or not isinstance(call.excinfo.value, AssertionError):
_patch_assertion_rewriter()


def __handle_failed_assertion():
raw_excinfo = sys.exc_info()
if raw_excinfo is None:
return

__handle_failed_assertion_impl(raw_excinfo)

if not INTERCEPT_ASSERTIONS:
raise


def __handle_failed_assertion_impl(raw_excinfo):
excinfo = ExceptionInfo.from_exc_info(raw_excinfo)

op, left, _ = recent_failure.pop()
if op != "==":
logger.debug(f"{item.nodeid} does not assert equality, and won't be replaced")
logger.debug("does not assert equality, and won't be replaced")
return

tb_entry = call.excinfo.traceback[0]
tb_entry = excinfo.traceback[0]
# not exactly sure why +1, but in tb_entry.__repr__
line_number_start = tb_entry.lineno + 1
line_number_end = line_number_start + len(tb_entry.statement.lines) - 1
original_location = slice(line_number_start, line_number_end)

path = tb_entry.path
tree = ast.parse(path.read())
tree = ast.parse(path.open().read())

for item in ast.walk(tree):
if isinstance(item, ast.Assert) and original_location.start == item.lineno:
# we need to _then_ check that the next compare item's
# ops[0] is Eq and then replace the comparator[0]
assert item.msg is None
assert len(item.test.comparators) == 1
assert len(item.test.ops) == 1
assert isinstance(item.test.ops[0], ast.Eq)
try:
assert item.msg is None
assert len(item.test.comparators) == 1
assert len(item.test.ops) == 1
assert isinstance(item.test.ops[0], ast.Eq)

ast.literal_eval(item.test.comparators[0])
except Exception:
continue

new_assert = copy.copy(item)
new_assert.test.comparators[0] = ast.Constant(value=left)

asts_modified[path].append((original_location, new_assert))
return outcome.get_result()
asts_modified[path].append((original_location, new_assert))


recent_failure: List[Tuple] = []


def pytest_assertrepr_compare(config, op, left, right):

recent_failure.append((op, left, right))


def pytest_sessionstart(session):
global INTERCEPT_ASSERTIONS
INTERCEPT_ASSERTIONS = session.config.getoption("--accept-continue")


def pytest_sessionfinish(session, exitstatus):
passed_accept = session.config.getoption("--accept")
if not passed_accept:
return

for path, new_asserts in asts_modified.items():
original = list(open(path).readlines())
# sort by line number
new_asserts = sorted(new_asserts, key=lambda x: x[0].start)

file = open(path + (".new" if not OVERWRITE else ""), "w+")
file = open(
str(path)
+ (
{
"new": ".new",
"overwrite": "",
}[passed_accept]
),
"w+",
)

for i, line in enumerate(original):
line_no = i + 1
Expand Down
Loading