Skip to content

Commit

Permalink
Support mask= argument in LocBody (#566)
Browse files Browse the repository at this point in the history
* A rough idea for supporting the `mask` argument in `GT.tab_style()`

* Ensure the `mask=` argument is used exclusively without specifying `columns` or `rows`

* Add tests for `resolve_mask_i()`

* Update test name

* Replace ambiguous variable name `masks` with `cellpos_data`

* Update the `resolve_mask_i()` logic based on team feedback

* Replace `assert` with `raise ValueError()`

* Update test cases for `GT.tab_style()`

* Add additional test cases for `resolve_mask_i()`

* Add docstring for `mask=` in `LocBody`

* Use `df.height` to get the number of rows in a DataFrame

* Rename `resolve_mask_i` to `resolve_mask`

* Apply code review suggestions for the `mask=` implementation in `LocBody`
  • Loading branch information
jrycw authored Jan 22, 2025
1 parent ef7d2ea commit 15b293e
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 8 deletions.
78 changes: 70 additions & 8 deletions great_tables/_locations.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,14 @@ class LocBody(Loc):
rows
The rows to target. Can either be a single row name or a series of row names provided in a
list.
mask
The cells to target. If the underlying wrapped DataFrame is a Polars DataFrame,
you can pass a Polars expression for cell-based selection. This argument must be used
exclusively and cannot be combined with the `columns=` or `rows=` arguments.
:::{.callout-warning}
`mask=` is still experimental.
:::
Returns
-------
Expand Down Expand Up @@ -539,6 +547,7 @@ class LocBody(Loc):

columns: SelectExpr = None
rows: RowSelectExpr = None
mask: PlExpr | None = None


@dataclass
Expand Down Expand Up @@ -823,6 +832,52 @@ def resolve_rows_i(
)


def resolve_mask(
    data: GTData | list[str],
    expr: PlExpr,
    excl_stub: bool = True,
    excl_group: bool = True,
) -> list[tuple[int, int, str]]:
    """Return data for creating `CellPos`, based on expr

    The expression is evaluated against the table's underlying DataFrame, and every cell
    whose resulting value is truthy is reported.

    NOTE(review): `data` is annotated as also accepting `list[str]`, but the body reads
    `data._tbl_data` and `data._boxhead`, so only a `GTData` works here -- confirm whether
    the annotation should be narrowed.

    Parameters
    ----------
    data
        The table data whose `_tbl_data` frame is evaluated and whose `_boxhead` supplies the
        stub and row-group column names.
    expr
        A Polars expression yielding one truthy/falsy value per targeted cell.
    excl_stub
        Whether to drop the stub column(s) from the evaluated mask.
    excl_group
        Whether to drop the row-group column(s) from the evaluated mask.

    Returns
    -------
    list[tuple[int, int, str]]
        One `(column index, row index, column name)` tuple per selected cell. Column indices
        refer to positions in the original DataFrame.

    Raises
    ------
    ValueError
        If `expr` is not a Polars expression, if the evaluated mask contains column names that
        are not in the original DataFrame, or if its number of rows differs from the original.
    """
    if not isinstance(expr, PlExpr):
        raise ValueError("Only Polars expressions can be passed to the `mask` argument.")

    frame: PlDataFrame = data._tbl_data
    frame_cols = frame.columns

    stub_var = data._boxhead.vars_from_type(ColInfoTypeEnum.stub)
    group_var = data._boxhead.vars_from_type(ColInfoTypeEnum.row_group)
    cols_excl = [*(stub_var if excl_stub else []), *(group_var if excl_group else [])]

    # `df.select()` raises `ColumnNotFoundError` if columns are missing from the original DataFrame.
    # `strict=False` lets `drop()` ignore excluded names that the mask did not produce.
    masked = frame.select(expr).drop(cols_excl, strict=False)

    # Validate that `masked.columns` exist in the `frame_cols`
    missing = set(masked.columns) - set(frame_cols)
    if missing:
        raise ValueError(
            "The `mask` expression produces extra columns, with names not in the original DataFrame."
            f"\n\nExtra columns: {missing}"
        )

    # Validate that row lengths are equal
    if masked.height != frame.height:
        # These lines need the `f` prefix; without it the placeholders are printed literally.
        raise ValueError(
            "The DataFrame length after applying `mask` differs from the original."
            f"\n\n* Original length: {frame.height}"
            f"\n* Mask length: {masked.height}"
        )

    cellpos_data: list[tuple[int, int, str]] = []  # column, row, colname for `CellPos`
    # Single pass over the column names instead of one `list.index()` scan per column.
    col_idx_map = {colname: col_idx for col_idx, colname in enumerate(frame_cols)}
    for row_idx, row_dict in enumerate(masked.iter_rows(named=True)):
        for colname, value in row_dict.items():
            if value:  # select only when `value` is True
                cellpos_data.append((col_idx_map[colname], row_idx, colname))
    return cellpos_data


# Resolve generic ======================================================================


Expand Down Expand Up @@ -868,15 +923,22 @@ def _(loc: LocStub, data: GTData) -> set[int]:

@resolve.register
def _(loc: LocBody, data: GTData) -> list[CellPos]:
    """Resolve a `LocBody` location into the list of body cells it targets.

    Cells are selected either by the combination of `loc.columns` and `loc.rows` (every
    column/row pairing) or by `loc.mask`; the two approaches are mutually exclusive.

    Raises
    ------
    ValueError
        If `loc.mask` is given together with `loc.columns` or `loc.rows`.
    """
    # Validate before doing any resolution work, so conflicting arguments fail fast and the
    # columns/rows are resolved only once.
    if (loc.columns is not None or loc.rows is not None) and loc.mask is not None:
        raise ValueError(
            "Cannot specify the `mask` argument along with `columns` or `rows` in `loc.body()`."
        )

    if loc.mask is None:
        rows = resolve_rows_i(data=data, expr=loc.rows)
        cols = resolve_cols_i(data=data, expr=loc.columns)
        # TODO: dplyr arranges by `Var1`, and does distinct (since you can tidyselect the same
        # thing multiple times
        # Element 0 of each resolved column is passed as `colname`; element 1 of the column
        # and of the row are the positional `CellPos` arguments.
        cell_pos = [
            CellPos(col[1], row[1], colname=col[0]) for col, row in itertools.product(cols, rows)
        ]
    else:
        # `resolve_mask()` yields tuples already ordered as the positional `CellPos` arguments.
        cellpos_data = resolve_mask(data=data, expr=loc.mask)
        cell_pos = [CellPos(*cellpos) for cellpos in cellpos_data]
    return cell_pos


Expand Down
76 changes: 76 additions & 0 deletions tests/test_tab_create_modify.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,19 @@
from great_tables._locations import LocBody
from great_tables._styles import CellStyleFill
from great_tables._tab_create_modify import tab_style
from polars import selectors as cs


@pytest.fixture
def gt():
    """Provide a GT table wrapping a two-row pandas DataFrame with columns `x` and `y`."""
    frame = pd.DataFrame({"x": [1, 2], "y": [4, 5]})
    return GT(frame)


@pytest.fixture
def gt2():
    """Provide a GT table wrapping a two-row Polars DataFrame with columns `x` and `y`."""
    frame = pl.DataFrame({"x": [1, 2], "y": [4, 5]})
    return GT(frame)


def test_tab_style(gt: GT):
style = CellStyleFill(color="blue")
new_gt = tab_style(gt, style, LocBody(["x"], [0]))
Expand Down Expand Up @@ -71,3 +77,73 @@ def test_tab_style_font_from_column():

assert rendered_html.find('<td style="font-family: Helvetica;" class="gt_row gt_right">1</td>')
assert rendered_html.find('<td style="font-family: Courier;" class="gt_row gt_right">2</td>')


def test_tab_style_loc_body_mask(gt2: GT):
    """A Polars `mask=` expression styles exactly the cells for which it is true."""
    fill = CellStyleFill(color="blue")
    styled = tab_style(gt2, fill, LocBody(mask=cs.numeric().gt(1.5)))

    # The input table must be left untouched; only the returned table gains styles.
    assert len(gt2._styles) == 0
    assert len(styled._styles) == 3

    # With x = [1, 2] and y = [4, 5], only the cell (row 0, "x") fails `> 1.5`.
    expected_cells = [(0, "y"), (1, "x"), (1, "y")]
    first, second, third = styled._styles

    for style_info in (first, second, third):
        assert style_info.styles[0] is fill

    for style_info, (rownum, colname) in zip((first, second, third), expected_cells):
        assert style_info.rownum == rownum
        assert style_info.colname == colname


def test_tab_style_loc_body_raises(gt2: GT):
    """Combining `mask=` with either `columns=` or `rows=` raises a `ValueError`."""
    fill = CellStyleFill(color="blue")
    expr = cs.numeric().gt(1.5)
    expected = "Cannot specify the `mask` argument along with `columns` or `rows` in `loc.body()`."

    # Each conflicting selector must be rejected on its own with the same message.
    for conflicting in ({"columns": ["x"]}, {"rows": [0]}):
        with pytest.raises(ValueError) as caught:
            tab_style(gt2, fill, LocBody(mask=expr, **conflicting))
        assert expected in caught.value.args[0]


def test_tab_style_loc_body_mask_not_polars_expression_raises(gt2: GT):
    """Passing something other than a Polars expression as `mask=` raises a `ValueError`."""
    expected = "Only Polars expressions can be passed to the `mask` argument."
    location = LocBody(mask="fake expression")  # a plain string is not a Polars expression

    with pytest.raises(ValueError) as caught:
        tab_style(gt2, CellStyleFill(color="blue"), location)
    assert expected in caught.value.args[0]


def test_tab_style_loc_body_mask_columns_not_inside_raises(gt2: GT):
    """A mask whose output column is not in the original DataFrame raises a `ValueError`."""
    expected = (
        "The `mask` expression produces extra columns, with names not in the original DataFrame."
    )
    # Un-aliased `pl.len()` yields a column that is neither "x" nor "y".
    location = LocBody(mask=pl.len())

    with pytest.raises(ValueError) as caught:
        tab_style(gt2, CellStyleFill(color="blue"), location)
    assert expected in caught.value.args[0]


def test_tab_style_loc_body_mask_rows_not_equal_raises(gt2: GT):
    """A mask that changes the number of rows raises a `ValueError`."""
    expected = "The DataFrame length after applying `mask` differs from the original."
    # Aliased to an existing column name so the column check passes; the aggregate
    # collapses the two input rows into one.
    location = LocBody(mask=pl.len().alias("x"))

    with pytest.raises(ValueError) as caught:
        tab_style(gt2, CellStyleFill(color="blue"), location)
    assert expected in caught.value.args[0]

0 comments on commit 15b293e

Please sign in to comment.