Skip to content

Issue 6810 type annotations for sqlalchemy.orm.mapped_collection #6

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 97 additions & 39 deletions lib/sqlalchemy/orm/mapped_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors

from __future__ import annotations

from typing import Any
from typing import Callable
from typing import Dict
from typing import Generic
from typing import Type
from typing import TYPE_CHECKING
from typing import TypeVar

from . import base
Expand All @@ -22,11 +23,24 @@
from ..sql import expression
from ..sql import roles

if TYPE_CHECKING:
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union

from . import AttributeEventToken
from . import Mapper
from ..sql.elements import ColumnElement

_KT = TypeVar("_KT", bound=Any)
_VT = TypeVar("_VT", bound=Any)

_F = TypeVar("_F", bound=Callable[[Any], Any])

class _PlainColumnGetter:

class _PlainColumnGetter(Generic[_KT]):
"""Plain column getter, stores collection of Column objects
directly.

Expand All @@ -38,21 +52,26 @@ class _PlainColumnGetter:

__slots__ = ("cols", "composite")

def __init__(self, cols):
def __init__(self, cols: Sequence[ColumnElement[_KT]]) -> None:
self.cols = cols
self.composite = len(cols) > 1

def __reduce__(self):
def __reduce__(
self,
) -> Tuple[
Type[_SerializableColumnGetterV2[_KT]],
Tuple[Sequence[Tuple[Optional[str], Optional[str]]]],
]:
return _SerializableColumnGetterV2._reduce_from_cols(self.cols)

def _cols(self, mapper):
def _cols(self, mapper: Mapper[_KT]) -> Sequence[ColumnElement[_KT]]:
return self.cols

def __call__(self, value):
def __call__(self, value: _KT) -> Union[_KT, Tuple[_KT, ...]]:
state = base.instance_state(value)
m = base._state_mapper(state)

key = [
key: List[_KT] = [
m._get_state_attr_by_column(state, state.dict, col)
for col in self._cols(m)
]
Expand All @@ -62,7 +81,7 @@ def __call__(self, value):
return key[0]


class _SerializableColumnGetterV2(_PlainColumnGetter):
class _SerializableColumnGetterV2(_PlainColumnGetter[_KT]):
"""Updated serializable getter which deals with
multi-table mapped classes.

Expand All @@ -76,38 +95,52 @@ class _SerializableColumnGetterV2(_PlainColumnGetter):

__slots__ = ("colkeys",)

def __init__(self, colkeys):
def __init__(
self, colkeys: Sequence[Tuple[Optional[str], Optional[str]]]
) -> None:
self.colkeys = colkeys
self.composite = len(colkeys) > 1

def __reduce__(self):
def __reduce__(
self,
) -> Tuple[
Type[_SerializableColumnGetterV2[_KT]],
Tuple[Sequence[Tuple[Optional[str], Optional[str]]]],
]:
return self.__class__, (self.colkeys,)

@classmethod
def _reduce_from_cols(cls, cols):
def _table_key(c):
def _reduce_from_cols(
cls, cols: Sequence[ColumnElement[_KT]]
) -> Tuple[
Type[_SerializableColumnGetterV2[_KT]],
Tuple[Sequence[Tuple[Optional[str], Optional[str]]]],
]:
def _table_key(c: ColumnElement[_KT]) -> Optional[str]:
if not isinstance(c.table, expression.TableClause):
return None
else:
return c.table.key
return c.table.key # type: ignore

colkeys = [(c.key, _table_key(c)) for c in cols]
return _SerializableColumnGetterV2, (colkeys,)

def _cols(self, mapper):
cols = []
def _cols(self, mapper: Mapper[_KT]) -> Sequence[ColumnElement[_KT]]:
cols: List[ColumnElement[_KT]] = []
metadata = getattr(mapper.local_table, "metadata", None)
for (ckey, tkey) in self.colkeys:
if tkey is None or metadata is None or tkey not in metadata:
cols.append(mapper.local_table.c[ckey])
cols.append(mapper.local_table.c[ckey]) # type: ignore
else:
cols.append(metadata.tables[tkey].c[ckey])
return cols


def column_keyed_dict(
mapping_spec, *, ignore_unpopulated_attribute: bool = False
):
mapping_spec: Union[Type[_KT], Callable[[_KT], _VT]],
*,
ignore_unpopulated_attribute: bool = False,
) -> Type[KeyFuncDict[_KT, _KT]]:
"""A dictionary-based collection type with column-based keying.

.. versionchanged:: 2.0 Renamed :data:`.column_mapped_collection` to
Expand Down Expand Up @@ -155,7 +188,8 @@ def column_keyed_dict(
]
keyfunc = _PlainColumnGetter(cols)
return _mapped_collection_cls(
keyfunc, ignore_unpopulated_attribute=ignore_unpopulated_attribute
keyfunc,
ignore_unpopulated_attribute=ignore_unpopulated_attribute,
)


Expand All @@ -169,13 +203,13 @@ def __call__(self, mapped_object: Any) -> Any:
dict_ = base.instance_dict(mapped_object)
return dict_.get(self.attr_name, base.NO_VALUE)

def __reduce__(self):
def __reduce__(self) -> Tuple[Type[_AttrGetter], Tuple[str]]:
return _AttrGetter, (self.attr_name,)


def attribute_keyed_dict(
attr_name: str, *, ignore_unpopulated_attribute: bool = False
) -> Type[KeyFuncDict]:
) -> Type[KeyFuncDict[_KT, _KT]]:
"""A dictionary-based collection type with attribute-based keying.

.. versionchanged:: 2.0 Renamed :data:`.attribute_mapped_collection` to
Expand Down Expand Up @@ -223,7 +257,7 @@ def attribute_keyed_dict(


def keyfunc_mapping(
keyfunc: Callable[[Any], _KT],
keyfunc: _F,
*,
ignore_unpopulated_attribute: bool = False,
) -> Type[KeyFuncDict[_KT, Any]]:
Expand Down Expand Up @@ -297,7 +331,12 @@ class KeyFuncDict(Dict[_KT, _VT]):

"""

def __init__(self, keyfunc, *, ignore_unpopulated_attribute=False):
def __init__(
self,
keyfunc: _F,
*,
ignore_unpopulated_attribute: bool = False,
) -> None:
"""Create a new collection with keying provided by keyfunc.

keyfunc may be any callable that takes an object and returns an object
Expand All @@ -315,21 +354,30 @@ def __init__(self, keyfunc, *, ignore_unpopulated_attribute=False):
self.ignore_unpopulated_attribute = ignore_unpopulated_attribute

@classmethod
def _unreduce(cls, keyfunc, values):
mp = KeyFuncDict(keyfunc)
def _unreduce(
cls, keyfunc: _F, values: Dict[_KT, _KT]
) -> "KeyFuncDict[_KT, _KT]":
mp: KeyFuncDict[_KT, _KT] = KeyFuncDict(keyfunc)
mp.update(values)
return mp

def __reduce__(self):
def __reduce__(
self,
) -> Tuple[
Callable[[_KT, _KT], KeyFuncDict[_KT, _KT]],
Tuple[Any, Union[Dict[_KT, _KT], Dict[_KT, _KT]]],
]:
return (KeyFuncDict._unreduce, (self.keyfunc, dict(self)))

def _raise_for_unpopulated(self, value, initiator):
def _raise_for_unpopulated(
self, value: _KT, initiator: Optional[AttributeEventToken]
) -> None:
mapper = base.instance_state(value).mapper

if initiator is None:
relationship = "unknown relationship"
else:
relationship = mapper.attrs[initiator.key]
relationship = f"{mapper.attrs[initiator.key]}"

raise sa_exc.InvalidRequestError(
f"In event triggered from population of attribute {relationship} "
Expand All @@ -345,9 +393,13 @@ def _raise_for_unpopulated(self, value, initiator):
f"parameter on the mapped collection factory."
)

@collection.appender
@collection.internally_instrumented
def set(self, value, _sa_initiator=None):
@collection.appender # type: ignore[misc]
@collection.internally_instrumented # type: ignore[misc]
def set(
self,
value: _KT,
_sa_initiator: Optional[AttributeEventToken] = None,
) -> None:
"""Add an item by value, consulting the keyfunc for the key."""

key = self.keyfunc(value)
Expand All @@ -358,11 +410,15 @@ def set(self, value, _sa_initiator=None):
else:
return

self.__setitem__(key, value, _sa_initiator)
self.__setitem__(key, value, _sa_initiator) # type: ignore[call-arg]

@collection.remover
@collection.internally_instrumented
def remove(self, value, _sa_initiator=None):
@collection.remover # type: ignore[misc]
@collection.internally_instrumented # type: ignore[misc]
def remove(
self,
value: _KT,
_sa_initiator: Optional[AttributeEventToken] = None,
) -> None:
"""Remove an item by value, consulting the keyfunc for the key."""

key = self.keyfunc(value)
Expand All @@ -381,12 +437,14 @@ def remove(self, value, _sa_initiator=None):
"based on mutable properties or properties that only obtain "
"values after flush?" % (value, self[key], key)
)
self.__delitem__(key, _sa_initiator)
self.__delitem__(key, _sa_initiator) # type: ignore[call-arg]


def _mapped_collection_cls(keyfunc, ignore_unpopulated_attribute):
class _MKeyfuncMapped(KeyFuncDict):
def __init__(self):
def _mapped_collection_cls(
keyfunc: _F, ignore_unpopulated_attribute: bool
) -> Type[KeyFuncDict[_KT, _KT]]:
class _MKeyfuncMapped(KeyFuncDict[_KT, _KT]):
def __init__(self) -> None:
super().__init__(
keyfunc,
ignore_unpopulated_attribute=ignore_unpopulated_attribute,
Expand Down