Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

don't use _outer_type if we don't have to #4528

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 23 additions & 9 deletions reflex/components/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,12 +357,16 @@ def __init_subclass__(cls, **kwargs):
if field.name not in props:
continue

field_type = types.value_inside_optional(
types.get_field_type(cls, field.name)
)

# Set default values for any props.
if types._issubclass(field.type_, Var):
if types._issubclass(field_type, Var):
field.required = False
if field.default is not None:
field.default = LiteralVar.create(field.default)
elif types._issubclass(field.type_, EventHandler):
elif types._issubclass(field_type, EventHandler):
field.required = False

# Ensure renamed props from parent classes are applied to the subclass.
Expand Down Expand Up @@ -426,7 +430,9 @@ def __init__(self, *args, **kwargs):
field_type = EventChain
elif key in props:
# Set the field type.
field_type = fields[key].type_
field_type = types.value_inside_optional(
types.get_field_type(type(self), key)
)

else:
continue
Expand All @@ -446,7 +452,10 @@ def __init__(self, *args, **kwargs):
if kwargs[key] is None:
raise TypeError

expected_type = fields[key].outer_type_.__args__[0]
expected_type = types.get_args(
types.get_field_type(type(self), key)
)[0]

# validate literal fields.
types.validate_literal(
key, value, expected_type, type(self).__name__
Expand All @@ -461,7 +470,7 @@ def __init__(self, *args, **kwargs):
except TypeError:
# If it is not a valid var, check the base types.
passed_type = type(value)
expected_type = fields[key].outer_type_
expected_type = types.get_field_type(type(self), key)
if types.is_union(passed_type):
# We need to check all possible types in the union.
passed_types = (
Expand Down Expand Up @@ -674,8 +683,11 @@ def get_event_triggers(

# Look for component specific triggers,
# e.g. variable declared as EventHandler types.
for field in self.get_fields().values():
if types._issubclass(field.outer_type_, EventHandler):
for name, field in self.get_fields().items():
if types._issubclass(
types.value_inside_optional(types.get_field_type(type(self), name)),
EventHandler,
):
args_spec = None
annotation = field.annotation
if (metadata := getattr(annotation, "__metadata__", None)) is not None:
Expand Down Expand Up @@ -787,9 +799,11 @@ def get_component_props(cls) -> set[str]:
"""
return {
name
for name, field in cls.get_fields().items()
for name in cls.get_fields()
if name in cls.get_props()
and types._issubclass(field.outer_type_, Component)
and types._issubclass(
types.value_inside_optional(types.get_field_type(cls, name)), Component
)
}

@classmethod
Expand Down
11 changes: 9 additions & 2 deletions reflex/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,12 @@
from typing_extensions import Annotated, get_type_hints

from reflex.utils.exceptions import ConfigError, EnvironmentVarValueError
from reflex.utils.types import GenericType, is_union, value_inside_optional
from reflex.utils.types import (
GenericType,
is_union,
true_type_for_pydantic_field,
value_inside_optional,
)

try:
import pydantic.v1 as pydantic
Expand Down Expand Up @@ -759,7 +764,9 @@ def update_from_env(self) -> dict[str, Any]:
# If the env var is set, override the config value.
if env_var is not None:
# Interpret the value.
value = interpret_env_var_value(env_var, field.outer_type_, field.name)
value = interpret_env_var_value(
env_var, true_type_for_pydantic_field(field), field.name
)

# Set the value.
updated_values[key] = value
Expand Down
10 changes: 4 additions & 6 deletions reflex/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,9 @@
from reflex.utils.types import (
_isinstance,
get_origin,
is_optional,
is_union,
override,
true_type_for_pydantic_field,
value_inside_optional,
)
from reflex.vars import VarData
Expand Down Expand Up @@ -282,7 +282,7 @@ def __call__(self, *args: Any) -> EventSpec:
from pydantic.v1.fields import ModelField


def _unwrap_field_type(type_: Type) -> Type:
def _unwrap_field_type(type_: types.GenericType) -> Type:
"""Unwrap rx.Field type annotations.

Args:
Expand Down Expand Up @@ -313,7 +313,7 @@ def get_var_for_field(cls: Type[BaseState], f: ModelField):
return dispatch(
field_name=field_name,
var_data=VarData.from_state(cls, f.name),
result_var_type=_unwrap_field_type(f.outer_type_),
result_var_type=_unwrap_field_type(true_type_for_pydantic_field(f)),
)


Expand Down Expand Up @@ -1329,9 +1329,7 @@ def __setattr__(self, name: str, value: Any):

if name in fields:
field = fields[name]
field_type = _unwrap_field_type(field.outer_type_)
if field.allow_none and not is_optional(field_type):
field_type = Union[field_type, None]
field_type = _unwrap_field_type(true_type_for_pydantic_field(field))
if not _isinstance(value, field_type):
console.deprecate(
"mismatched-type-assignment",
Expand Down
57 changes: 40 additions & 17 deletions reflex/utils/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
Callable,
ClassVar,
Dict,
ForwardRef,
FrozenSet,
Iterable,
List,
Expand Down Expand Up @@ -269,6 +270,20 @@ def is_optional(cls: GenericType) -> bool:
return is_union(cls) and type(None) in get_args(cls)


def true_type_for_pydantic_field(f: ModelField):
"""Get the type for a pydantic field.

Args:
f: The field to get the type for.

Returns:
The type for the field.
"""
if not isinstance(f.annotation, (str, ForwardRef)):
return f.annotation
return f.outer_type_


def value_inside_optional(cls: GenericType) -> GenericType:
"""Get the value inside an Optional type or the original type.

Expand All @@ -283,6 +298,29 @@ def value_inside_optional(cls: GenericType) -> GenericType:
return cls


def get_field_type(cls: GenericType, field_name: str) -> GenericType | None:
"""Get the type of a field in a class.

Args:
cls: The class to check.
field_name: The name of the field to check.

Returns:
The type of the field, if it exists, else None.
"""
if (
hasattr(cls, "__fields__")
and field_name in cls.__fields__
and hasattr(cls.__fields__[field_name], "annotation")
and not isinstance(cls.__fields__[field_name].annotation, (str, ForwardRef))
):
return cls.__fields__[field_name].annotation
type_hints = get_type_hints(cls)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

get_type_hints is a really slow function, it doesn't make sense to call it multiple times for the same cls without some kind of caching.

we should profile and see what affect this change may or may not have on component __init__

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yea i was thinking the same, we could also be more specific when we use this vs the other one

if field_name in type_hints:
return type_hints[field_name]
return None


def get_property_hint(attr: Any | None) -> GenericType | None:
"""Check if an attribute is a property and return its type hint.

Expand Down Expand Up @@ -320,24 +358,9 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None
if hint := get_property_hint(attr):
return hint

if (
hasattr(cls, "__fields__")
and name in cls.__fields__
and hasattr(cls.__fields__[name], "outer_type_")
):
if hasattr(cls, "__fields__") and name in cls.__fields__:
# pydantic models
field = cls.__fields__[name]
type_ = field.outer_type_
if isinstance(type_, ModelField):
type_ = type_.type_
if (
not field.required
and field.default is None
and field.default_factory is None
):
# Ensure frontend uses null coalescing when accessing.
type_ = Optional[type_]
return type_
Comment on lines -329 to -340
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i'm really hesitant about removing these lines, unless we're sure this is not going to be hit.

@benedikt-bartscher do you know under what circumstances this bit of code gets used?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

import reflex as rx


class M1(rx.Base):
    foo: str = "foo"


class M2(rx.Base):
    m1: M1 = None


class State(rx.State):
    m2: M2 = M2()


def index() -> rx.Component:
    return rx.container(
        rx.text(State.m2.m1.foo),
    )


app = rx.App()
app.add_page(index)

This is at least one of the cases; default is None, but the field is not annotated as Optional (which is incorrect and i'm surprised pydantic doesn't yell), but Reflex still try to be helpful and use "null coalescing" to avoid frontend exceptions.

Again, why pydantic allows the default to not match the annotation and pass validation i'm unsure.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This one is interesting, default is None when it's unset and when it's set to None :(

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@masenf sorry, I did not see your question for some reason. But it seems like you found a case.

This one is interesting, default is None when it's unset and when it's set to None :(

@adhami3310 you can use this to determine unset defaults

from pydantic.v1 import BaseModel


class User(BaseModel):
    no_default: str | None
    none_default: str | None = None


no_default_field = User.__fields__["no_default"]
none_default_field = User.__fields__["none_default"]
print(no_default_field.field_info.default)
print(none_default_field.field_info.default)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TIL PydanticUndefined

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

last time i checked, FieldInfo doesn't contain the type, but I wouldn't be surprised if i missed something

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean for default and default factory values.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i made the check for default use None, there's little I can do with default_factory returning None (i hope no one thinks that's a good idea)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you sure this is a bad idea?

diff --git a/reflex/state.py b/reflex/state.py
index e7e6bcf3..46ba3133 100644
--- a/reflex/state.py
+++ b/reflex/state.py
@@ -1099,7 +1099,10 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
         if (
             not field.required
             and field.default is None
-            and field.default_factory is None
+            and (
+                field.default_factory is None
+                or types.default_factory_can_be_none(field.default_factory)
+            )
             and not types.is_optional(prop._var_type)
         ):
             # Ensure frontend uses null coalescing when accessing.
diff --git a/reflex/utils/types.py b/reflex/utils/types.py
index b8bcbf2d..60eb456d 100644
--- a/reflex/utils/types.py
+++ b/reflex/utils/types.py
@@ -333,7 +333,10 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None
         if (
             not field.required
             and field.default is None
-            and field.default_factory is None
+            and (
+                field.default_factory is None
+                or default_factory_can_be_none(field.default_factory)
+            )
         ):
             # Ensure frontend uses null coalescing when accessing.
             type_ = Optional[type_]
@@ -893,3 +896,18 @@ def typehint_issubclass(possible_subclass: Any, possible_superclass: Any) -> boo
         for provided_arg, accepted_arg in zip(provided_args, accepted_args)
         if accepted_arg is not Any
     )
+
+
+def default_factory_can_be_none(default_factory: Callable) -> bool:
+    """Check if the default factory can return None.
+
+    Args:
+        default_factory: The default factory to check.
+
+    Returns:
+        Whether the default factory can return None.
+    """
+    type_hints = get_type_hints(default_factory)
+    if hint := type_hints.get("return"):
+        return is_optional(hint)
+    return default_factory() is None

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

most default factories are lambdas so i would imagine you're not getting a type hint in most cases, calling it could give you should be a reliable indicator but i'm afraid of side effects or performance costs

return get_field_type(cls, name)
elif isinstance(cls, type) and issubclass(cls, DeclarativeBase):
insp = sqlalchemy.inspect(cls)
if name in insp.columns:
Expand Down
Loading