Skip to content

Commit

Permalink
Liniting
Browse files Browse the repository at this point in the history
  • Loading branch information
AntonDeMeester committed Nov 14, 2023
1 parent c21ff69 commit 254fb13
Show file tree
Hide file tree
Showing 10 changed files with 53 additions and 41 deletions.
1 change: 1 addition & 0 deletions docs_src/tutorial/fastapi/app_testing/tutorial001/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def get_session():
def on_startup():
create_db_and_tables()


@app.post("/heroes/", response_model=HeroRead)
def create_hero(*, session: Session = Depends(get_session), hero: HeroCreate):
if IS_PYDANTIC_V2:
Expand Down
2 changes: 1 addition & 1 deletion docs_src/tutorial/fastapi/read_one/tutorial001.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

from fastapi import FastAPI, HTTPException
from sqlmodel import Field, Session, SQLModel, create_engine, select

from sqlmodel.compat import IS_PYDANTIC_V2


class HeroBase(SQLModel):
name: str = Field(index=True)
secret_name: str
Expand Down
39 changes: 25 additions & 14 deletions sqlmodel/compat.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from types import NoneType
from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -11,7 +12,6 @@
get_args,
get_origin,
)
from types import NoneType

from pydantic import VERSION as PYDANTIC_VERSION

Expand All @@ -20,11 +20,12 @@

if IS_PYDANTIC_V2:
from pydantic import ConfigDict
from pydantic_core import PydanticUndefined as PydanticUndefined, PydanticUndefinedType as PydanticUndefinedType # noqa
from pydantic_core import PydanticUndefined as PydanticUndefined # noqa
from pydantic_core import PydanticUndefinedType as PydanticUndefinedType
else:
from pydantic import BaseConfig # noqa
from pydantic.fields import ModelField # noqa
from pydantic.fields import Undefined as PydanticUndefined, UndefinedType as PydanticUndefinedType, SHAPE_SINGLETON
from pydantic import BaseConfig # noqa
from pydantic.fields import ModelField # noqa
from pydantic.fields import Undefined as PydanticUndefined, SHAPE_SINGLETON
from pydantic.typing import resolve_annotations

if TYPE_CHECKING:
Expand Down Expand Up @@ -68,26 +69,29 @@ def get_config_value(


def set_config_value(
model: InstanceOrType["SQLModel"], parameter: str, value: Any, v1_parameter: str = None
model: InstanceOrType["SQLModel"],
parameter: str,
value: Any,
v1_parameter: str = None,
) -> None:
if IS_PYDANTIC_V2:
model.model_config[parameter] = value # type: ignore
model.model_config[parameter] = value # type: ignore
else:
setattr(model.__config__, v1_parameter or parameter, value) # type: ignore


def get_model_fields(model: InstanceOrType["SQLModel"]) -> Dict[str, "FieldInfo"]:
if IS_PYDANTIC_V2:
return model.model_fields # type: ignore
return model.model_fields # type: ignore
else:
return model.__fields__ # type: ignore
return model.__fields__ # type: ignore


def get_fields_set(model: InstanceOrType["SQLModel"]) -> set[str]:
if IS_PYDANTIC_V2:
return model.__pydantic_fields_set__
else:
return model.__fields_set__ # type: ignore
return model.__fields_set__ # type: ignore


def set_fields_set(
Expand All @@ -103,13 +107,17 @@ def set_attribute_mode(cls: Type["SQLModelMetaclass"]) -> None:
if IS_PYDANTIC_V2:
cls.model_config["read_from_attributes"] = True
else:
cls.__config__.read_with_orm_mode = True # type: ignore
cls.__config__.read_with_orm_mode = True # type: ignore


def get_annotations(class_dict: dict[str, Any]) -> dict[str, Any]:
if IS_PYDANTIC_V2:
return class_dict.get("__annotations__", {})
else:
return resolve_annotations(class_dict.get("__annotations__", {}),class_dict.get("__module__", None))
return resolve_annotations(
class_dict.get("__annotations__", {}), class_dict.get("__module__", None)
)


def is_table(class_dict: dict[str, Any]) -> bool:
config: SQLModelConfig = {}
Expand All @@ -125,6 +133,7 @@ def is_table(class_dict: dict[str, Any]) -> bool:
return kw_table
return False


def get_relationship_to(
name: str,
rel_info: "RelationshipInfo",
Expand Down Expand Up @@ -170,6 +179,7 @@ def set_empty_defaults(annotations: Dict[str, Any], class_dict: Dict[str, Any])
"""
if IS_PYDANTIC_V2:
from .main import FieldInfo

# Pydantic v2 sets a __pydantic_core_schema__ which is very hard to change. Changing the fields does not do anything
for key in annotations.keys():
value = class_dict.get(key, PydanticUndefined)
Expand All @@ -180,9 +190,10 @@ def set_empty_defaults(annotations: Dict[str, Any], class_dict: Dict[str, Any])
value.default in (PydanticUndefined, Ellipsis)
) and value.default_factory is None:
# So we can check for nullable
value.original_default = value.default
value.original_default = value.default
value.default = None


def is_field_noneable(field: "FieldInfo") -> bool:
if IS_PYDANTIC_V2:
if getattr(field, "nullable", PydanticUndefined) is not PydanticUndefined:
Expand All @@ -205,4 +216,4 @@ def is_field_noneable(field: "FieldInfo") -> bool:
return field.allow_none and (
field.shape != SHAPE_SINGLETON or not field.sub_fields
)
return False
return False
3 changes: 2 additions & 1 deletion sqlmodel/engine/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def __next__(self) -> _T:

def first(self) -> Optional[_T]:
return super().first()

def one_or_none(self) -> Optional[_T]:
return super().one_or_none()

Expand Down Expand Up @@ -75,4 +76,4 @@ def one(self) -> _T: # type: ignore
return super().one() # type: ignore

def scalar(self) -> Optional[_T]:
return super().scalar()
return super().scalar()
2 changes: 1 addition & 1 deletion sqlmodel/ext/asyncio/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,4 +91,4 @@ async def exec(
execution_options=execution_options,
bind_arguments=bind_arguments,
**kw,
)
)
39 changes: 18 additions & 21 deletions sqlmodel/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,27 +52,26 @@
from .compat import (
IS_PYDANTIC_V2,
NoArgAnyCallable,
PydanticModelConfig,
PydanticUndefined,
PydanticUndefinedType,
SQLModelConfig,
get_annotations,
get_config_value,
get_model_fields,
get_relationship_to,
is_field_noneable,
is_table,
set_config_value,
set_empty_defaults,
set_fields_set,
is_table,
is_field_noneable,
PydanticModelConfig,
get_annotations
)
from .sql.sqltypes import GUID, AutoString

if not IS_PYDANTIC_V2:
from pydantic.errors import ConfigError, DictError
from pydantic.main import validate_model
from pydantic.utils import ROOT_KEY
from pydantic.typing import resolve_annotations

_T = TypeVar("_T")

Expand Down Expand Up @@ -444,8 +443,7 @@ def __new__(
) # skip dunder methods and attributes
}
config_kwargs = {
key: kwargs[key]
for key in kwargs.keys() & allowed_config_kwargs
key: kwargs[key] for key in kwargs.keys() & allowed_config_kwargs
}
config_table = is_table(class_dict)
if config_table:
Expand Down Expand Up @@ -690,7 +688,7 @@ def __init__(__pydantic_self__, **data: Any) -> None:
# settable attribute
if IS_PYDANTIC_V2:
old_dict = __pydantic_self__.__dict__.copy()
__pydantic_self__.super().__init__(**data) # noqa
__pydantic_self__.super().__init__(**data) # noqa
__pydantic_self__.__dict__ = {**old_dict, **__pydantic_self__.__dict__}
non_pydantic_keys = data.keys() - __pydantic_self__.model_fields
else:
Expand All @@ -699,7 +697,7 @@ def __init__(__pydantic_self__, **data: Any) -> None:
)
# Only raise errors if not a SQLModel model
if (
not getattr(__pydantic_self__.__config__, "table", False) # noqa
not getattr(__pydantic_self__.__config__, "table", False) # noqa
and validation_error
):
raise validation_error
Expand Down Expand Up @@ -764,7 +762,7 @@ def from_orm(
cls: Type[_TSQLModel], obj: Any, update: Optional[Dict[str, Any]] = None
) -> _TSQLModel:
# Duplicated from Pydantic
if not cls.__config__.orm_mode: # noqa: attr-defined
if not cls.__config__.orm_mode: # noqa: attr-defined
raise ConfigError(
"You must have the config attribute orm_mode=True to use from_orm"
)
Expand All @@ -777,7 +775,7 @@ def from_orm(
if update is not None:
obj = {**obj, **update}
# End SQLModel support dict
if not getattr(cls.__config__, "table", False): # noqa
if not getattr(cls.__config__, "table", False): # noqa
# If not table, normal Pydantic code
m: _TSQLModel = cls.__new__(cls)
else:
Expand All @@ -788,21 +786,21 @@ def from_orm(
if validation_error:
raise validation_error
# Updated to trigger SQLAlchemy internal handling
if not getattr(cls.__config__, "table", False): # noqa
if not getattr(cls.__config__, "table", False): # noqa
object.__setattr__(m, "__dict__", values)
else:
for key, value in values.items():
setattr(m, key, value)
# Continue with standard Pydantic logic
object.__setattr__(m, "__fields_set__", fields_set)
m._init_private_attributes() # noqa
m._init_private_attributes() # noqa
return m

@classmethod
def parse_obj(
cls: Type[_TSQLModel], obj: Any, update: Optional[Dict[str, Any]] = None
) -> _TSQLModel:
obj = cls._enforce_dict_if_root(obj) # noqa
obj = cls._enforce_dict_if_root(obj) # noqa
# SQLModel, support update dict
if update is not None:
obj = {**obj, **update}
Expand All @@ -814,7 +812,7 @@ def parse_obj(
def validate(cls: Type[_TSQLModel], value: Any) -> _TSQLModel:
if isinstance(value, cls):
return (
value.copy() if cls.__config__.copy_on_model_validation else value # noqa
value.copy() if cls.__config__.copy_on_model_validation else value # noqa
)

value = cls._enforce_dict_if_root(value)
Expand All @@ -826,9 +824,9 @@ def validate(cls: Type[_TSQLModel], value: Any) -> _TSQLModel:
# Reset fields set, this would have been done in Pydantic in __init__
object.__setattr__(model, "__fields_set__", fields_set)
return model
elif cls.__config__.orm_mode: # noqa
elif cls.__config__.orm_mode: # noqa
return cls.from_orm(value)
elif cls.__custom_root_type__: # noqa
elif cls.__custom_root_type__: # noqa
return cls.parse_obj(value)
else:
try:
Expand All @@ -852,20 +850,20 @@ def _calculate_keys(
# Do not include relationships as that would easily lead to infinite
# recursion, or traversing the whole database
return (
self.__fields__.keys() # noqa
self.__fields__.keys() # noqa
) # | self.__sqlmodel_relationships__.keys()

keys: AbstractSet[str]
if exclude_unset:
keys = self.__fields_set__.copy() # noqa
keys = self.__fields_set__.copy() # noqa
else:
# Original in Pydantic:
# keys = self.__dict__.keys()
# Updated to not return SQLAlchemy attributes
# Do not include relationships as that would easily lead to infinite
# recursion, or traversing the whole database
keys = (
self.__fields__.keys() # noqa
self.__fields__.keys() # noqa
) # | self.__sqlmodel_relationships__.keys()
if include is not None:
keys &= include.keys()
Expand All @@ -877,4 +875,3 @@ def _calculate_keys(
keys -= {k for k, v in exclude.items() if _value_items_is_true(v)}

return keys

2 changes: 1 addition & 1 deletion sqlmodel/orm/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,4 +138,4 @@ def get(
with_for_update=with_for_update,
identity_token=identity_token,
execution_options=execution_options,
)
)
2 changes: 1 addition & 1 deletion sqlmodel/sql/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,4 +434,4 @@ def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]: # type:
def col(column_expression: Any) -> ColumnClause: # type: ignore
if not isinstance(column_expression, (ColumnClause, Column, InstrumentedAttribute)):
raise RuntimeError(f"Not a SQLAlchemy column: {column_expression}")
return column_expression
return column_expression
2 changes: 1 addition & 1 deletion sqlmodel/sql/sqltypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,4 @@ def process_result_value(self, value: Any, dialect: Dialect) -> Optional[uuid.UU
else:
if not isinstance(value, uuid.UUID):
value = uuid.UUID(value)
return cast(uuid.UUID, value)
return cast(uuid.UUID, value)
2 changes: 2 additions & 0 deletions tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from .conftest import needs_pydanticv1, needs_pydanticv2


@needs_pydanticv1
def test_validation_pydantic_v1(clear_sqlmodel):
"""Test validation of implicit and explicit None values.
Expand Down Expand Up @@ -34,6 +35,7 @@ def reject_none(cls, v):
with pytest.raises(ValidationError):
Hero.from_orm({"name": None, "age": 25})


@needs_pydanticv2
def test_validation_pydantic_v2(clear_sqlmodel):
"""Test validation of implicit and explicit None values.
Expand Down

0 comments on commit 254fb13

Please sign in to comment.