Skip to content

Commit

Permalink
Add dataclasses.equal(), @data_eq
Browse files Browse the repository at this point in the history
* Add custom dataclass unit test
  • Loading branch information
holl- committed Dec 19, 2024
1 parent ddf4ca9 commit a733c22
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 6 deletions.
2 changes: 1 addition & 1 deletion phiml/dataclasses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,6 @@

from functools import cached_property

from ._dataclasses import sliceable, data_fields, non_data_fields, config_fields, special_fields, replace, getitem
from ._dataclasses import sliceable, data_fields, non_data_fields, config_fields, special_fields, replace, getitem, equal, data_eq

__all__ = [key for key in globals().keys() if not key.startswith('_')]
41 changes: 37 additions & 4 deletions phiml/dataclasses/_dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from phiml.dataclasses._dep import get_unchanged_cache
from phiml.math import DimFilter, shape, Shape
from phiml.math._magic_ops import slice_, variable_attributes
from phiml.math._tensors import disassemble_tree, Tensor, assemble_tree
from phiml.math._tensors import disassemble_tree, Tensor, assemble_tree, equality_by_shape_and_value, equality_by_ref
from phiml.math.magic import slicing_dict, BoundDim

PhiMLDataclass = TypeVar("PhiMLDataclass")
Expand Down Expand Up @@ -52,10 +52,23 @@ def __dataclass_repr__(obj):
return f"{type(obj).__name__}[{content}]"
cls.__repr__ = __dataclass_repr__
return cls
return wrap(cls) if cls is not None else wrap # See if we're being called as @dataclass or @dataclass().

if cls is None: # See if we're being called as @dataclass or @dataclass().
return wrap
return wrap(cls)

def data_eq(cls=None, /, *, rel_tolerance=0., abs_tolerance=0., equal_nan=True, compare_tensors_by_ref=False):
def wrap(cls):
assert cls.__dataclass_params__.eq, f"@data_eq can only be used with dataclasses with eq=True."
cls.__default_dataclass_eq__ = cls.__eq__
def __tensor_eq__(obj, other):
if compare_tensors_by_ref:
with equality_by_ref():
return cls.__default_dataclass_eq__(obj, other)
with equality_by_shape_and_value(rel_tolerance, abs_tolerance, equal_nan):
return cls.__default_dataclass_eq__(obj, other)
cls.__eq__ = __tensor_eq__
# __ne__ calls `not __eq__()` by default
return cls
return wrap(cls) if cls is not None else wrap # See if we're being called as @dataclass or @dataclass().


NON_ATTR_TYPES = str, int, float, complex, bool, Shape, slice, Callable
Expand Down Expand Up @@ -236,3 +249,23 @@ def __getitem__(self, item):
cache = {k: slice_(v, item) for k, v in obj.__dict__.items() if isinstance(getattr(type(obj), k, None), cached_property) and not isinstance(v, Shape)}
new_obj.__dict__.update(cache)
return new_obj


def equal(obj1, obj2, rel_tolerance=0., abs_tolerance=0., equal_nan=True):
"""
Checks if two
Args:
obj1:
obj2:
rel_tolerance:
abs_tolerance:
equal_nan:
Returns:
"""
cls = type(obj1)
eq_fn = cls.__default_dataclass_eq__ if hasattr(cls, '__default_dataclass_eq__') else cls.__eq__
with equality_by_shape_and_value(rel_tolerance, abs_tolerance, equal_nan):
return eq_fn(obj1, obj2)
2 changes: 1 addition & 1 deletion phiml/math/_tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -944,7 +944,7 @@ def equality_by_ref():


@contextmanager
def equality_by_shape_and_value(rel_tolerance=0, abs_tolerance=0, equal_nan=False):
def equality_by_shape_and_value(rel_tolerance=0., abs_tolerance=0., equal_nan=False):
"""
Enables Tensor.__bool__
"""
Expand Down
Empty file.
48 changes: 48 additions & 0 deletions tests/commit/dataclasses/test_dataclasses.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from dataclasses import dataclass
from typing import Tuple, Sequence, Dict
from unittest import TestCase

from phiml.dataclasses import data_fields, config_fields, special_fields, sliceable, data_eq, equal
from phiml.math import Tensor, vec, wrap, shape, assert_close


class TestDataclasses(TestCase):

def test_field_types(self):
@dataclass
class Custom:
variable_attrs: Sequence[str]
age: int
next: 'Custom'
conf: Dict[str, Sequence[Tuple[str, float, complex, bool, int, slice]]]
data_names = [f.name for f in data_fields(Custom)]
self.assertEqual(['next'], data_names)
config_names = [f.name for f in config_fields(Custom)]
self.assertEqual(['age', 'conf'], config_names)
special_names = [f.name for f in special_fields(Custom)]
self.assertEqual(['variable_attrs'], special_names)

def test_sliceable(self):
@sliceable
@dataclass(frozen=True)
class Custom:
pos: Tensor
edges: Dict[str, Tensor]
c = Custom(vec(x=1, y=2), {'lo': wrap([-1, 1], 'b:b')})
self.assertEqual(('b', 'vector'), shape(c).names)
assert_close(c['y,x'].pos, c.pos['y,x'])
assert_close(c.vector['y,x'].pos, c.pos['y,x'])

def test_data_eq(self):
@data_eq(abs_tolerance=.2)
@dataclass(frozen=True)
class Custom:
pos: Tensor
c1 = Custom(vec(x=0, y=1))
c2 = Custom(vec(x=3, y=4))
c11 = Custom(vec(x=.1, y=1.1))
self.assertNotEqual(c1, c2)
self.assertEqual(c1, c1)
self.assertEqual(c1, c11)
self.assertFalse(equal(c1, c11, abs_tolerance=0))
self.assertTrue(equal(c1, c2, abs_tolerance=3))

0 comments on commit a733c22

Please sign in to comment.