From 4750700665a5f2cfc29d401fefc036681cda7d2d Mon Sep 17 00:00:00 2001 From: Jelmer Draaijer Date: Wed, 14 Feb 2024 11:10:49 +0100 Subject: [PATCH] Get primary key fields from mapper --- starlette_admin/contrib/sqla/view.py | 47 ++++++++++++++++++---------- 1 file changed, 30 insertions(+), 17 deletions(-) diff --git a/starlette_admin/contrib/sqla/view.py b/starlette_admin/contrib/sqla/view.py index 90558591..4f20bf9d 100644 --- a/starlette_admin/contrib/sqla/view.py +++ b/starlette_admin/contrib/sqla/view.py @@ -110,10 +110,14 @@ def __init__( self.fields = (converter or ModelConverter()).convert_fields_list( fields=self.fields, model=self.model, mapper=mapper ) - self._setup_primary_key() + self._setup_primary_key(mapper) self.exclude_fields_from_list = normalize_list(self.exclude_fields_from_list) # type: ignore - self.exclude_fields_from_detail = normalize_list(self.exclude_fields_from_detail) # type: ignore - self.exclude_fields_from_create = normalize_list(self.exclude_fields_from_create) # type: ignore + self.exclude_fields_from_detail = normalize_list( + self.exclude_fields_from_detail + ) # type: ignore + self.exclude_fields_from_create = normalize_list( + self.exclude_fields_from_create + ) # type: ignore self.exclude_fields_from_edit = normalize_list(self.exclude_fields_from_edit) # type: ignore _default_list = [ field.name @@ -136,19 +140,20 @@ def __init__( ) super().__init__() - def _setup_primary_key(self) -> None: + def _setup_primary_key(self, mapper: Mapper) -> None: # Detect the primary key attribute(s) of the model - _pk_attrs = [] + self._pk_column: Union[ Tuple[InstrumentedAttribute, ...], InstrumentedAttribute ] = () self._pk_coerce: Union[Tuple[type, ...], type] = () - for key in self.model.__dict__: - attr = getattr(self.model, key) - if isinstance(attr, InstrumentedAttribute) and getattr( - attr, "primary_key", False - ): - _pk_attrs.append(key) + + # mapper._primary_key_propkeys but then ordered by occurrence in the model + pks_by_table = [mapper._pks_by_table[table] for table in mapper.tables] # type: ignore[attr-defined] + _pk_attrs: List[str] = [] + for table in pks_by_table: + _pk_attrs += [mapper._columntoproperty[c].key for c in table] # type: ignore[attr-defined] + if len(_pk_attrs) > 1: self._pk_column = tuple(getattr(self.model, attr) for attr in _pk_attrs) self._pk_coerce = tuple( @@ -301,7 +306,8 @@ async def find_all( if isinstance(session, AsyncSession): return (await session.execute(stmt)).scalars().unique().all() return ( - (await anyio.to_thread.run_sync(session.execute, stmt)) # type: ignore[arg-type] + (await anyio.to_thread.run_sync(session.execute, stmt)) + # type: ignore[arg-type] .scalars() .unique() .all() @@ -327,7 +333,9 @@ async def find_by_pk(self, request: Request, pk: Any) -> Any: == (_pk == "True") # to avoid bool("False") which is True ) for _pk_col, _coerce, _pk in zip( - self._pk_column, self._pk_coerce, iterdecode(pk) # type: ignore[type-var,arg-type] + self._pk_column, + self._pk_coerce, + iterdecode(pk), # type: ignore[type-var,arg-type] ) ) else: @@ -340,7 +348,8 @@ async def find_by_pk(self, request: Request, pk: Any) -> Any: if isinstance(session, AsyncSession): return (await session.execute(stmt)).scalars().unique().one_or_none() return ( - (await anyio.to_thread.run_sync(session.execute, stmt)) # type: ignore[arg-type] + (await anyio.to_thread.run_sync(session.execute, stmt)) + # type: ignore[arg-type] .scalars() .unique() .one_or_none() @@ -376,7 +385,8 @@ async def _exec_find_by_pks( if isinstance(session, AsyncSession): return (await session.execute(stmt)).scalars().unique().all() return ( - (await anyio.to_thread.run_sync(session.execute, stmt)) # type: ignore[arg-type] + (await anyio.to_thread.run_sync(session.execute, stmt)) + # type: ignore[arg-type] .scalars() .unique() .all() @@ -411,7 +421,8 @@ async def _get_multiple_pks_in_clause( tuple( (_coerce(_pk) if _coerce is not bool else _pk == "True") for _coerce, _pk in zip( - self._pk_coerce, decoded_pk # type: ignore[type-var,arg-type] + self._pk_coerce, + decoded_pk, # type: ignore[type-var,arg-type] ) ) for decoded_pk in decoded_pks @@ -427,7 +438,9 @@ async def _get_multiple_pks_in_clause( else (_pk_col == (_pk == "True")) ) # to avoid bool("False") which is True for _pk_col, _coerce, _pk in zip( - self._pk_column, self._pk_coerce, decoded_pk # type: ignore[type-var,arg-type] + self._pk_column, + self._pk_coerce, + decoded_pk, # type: ignore[type-var,arg-type] ) ) )