summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/util/_py_collections.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2022-01-09 11:49:02 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2022-01-24 15:14:01 -0500
commitff1ab665cb1694b85085680d1a02c7c11fa2a6d4 (patch)
treebb8414b44946d9cb96361d7dcd4a4541d8254672 /lib/sqlalchemy/util/_py_collections.py
parentaba3ab247da4628e4e7baf993702e2efaccbc547 (diff)
downloadsqlalchemy-ff1ab665cb1694b85085680d1a02c7c11fa2a6d4.tar.gz
mypy: sqlalchemy.util
Starting to set up practices and conventions to get the library typed. Key goals for typing are: 1. whole library can pass mypy without any strict turned on. 2. we can incrementally turn on some strict flags on a per-package/ module basis, as here we turn on more strictness for sqlalchemy.util, exc, and log 3. mypy ORM plugin tests work fully without sqlalchemy2-stubs installed 4. public facing methods all have return types, major parameter signatures filled in also 5. Foundational elements like util etc. are typed enough so that we can use them in fully typed internals higher up the stack. Conventions set up here: 1. we can use lots of config in setup.cfg to limit where mypy is throwing errors and how detailed it should be in different packages / modules. We can use this to push up gerrits that will pass tests fully without everything being typed. 2. a new tox target pep484 is added. this links to a new jenkins pep484 job that works across all projects (alembic, dogpile, etc.) We've worked around some mypy bugs that will likely be around for awhile, and also set up some core practices for how to deal with certain things such as public_factory modules (mypy won't accept a module from a callable at all, so need to use simple type checking conditionals). References: #6810 Change-Id: I80be58029896a29fd9f491aa3215422a8b705e12
Diffstat (limited to 'lib/sqlalchemy/util/_py_collections.py')
-rw-r--r--lib/sqlalchemy/util/_py_collections.py168
1 files changed, 119 insertions, 49 deletions
diff --git a/lib/sqlalchemy/util/_py_collections.py b/lib/sqlalchemy/util/_py_collections.py
index ff61f6ca9..a4e4b8b5d 100644
--- a/lib/sqlalchemy/util/_py_collections.py
+++ b/lib/sqlalchemy/util/_py_collections.py
@@ -1,17 +1,52 @@
from itertools import filterfalse
+from typing import Any
+from typing import Dict
+from typing import Generic
+from typing import Iterable
+from typing import Iterator
+from typing import List
+from typing import NoReturn
+from typing import Optional
+from typing import Set
+from typing import TypeVar
+
+_T = TypeVar("_T", bound=Any)
+_KT = TypeVar("_KT", bound=Any)
+_VT = TypeVar("_VT", bound=Any)
class ImmutableContainer:
- def _immutable(self, *arg, **kw):
+ def _immutable(self, *arg: Any, **kw: Any) -> NoReturn:
raise TypeError("%s object is immutable" % self.__class__.__name__)
- __delitem__ = __setitem__ = __setattr__ = _immutable
+ def __delitem__(self, key: Any) -> NoReturn:
+ self._immutable()
+ def __setitem__(self, key: Any, value: Any) -> NoReturn:
+ self._immutable()
-class immutabledict(ImmutableContainer, dict):
+ def __setattr__(self, key: str, value: Any) -> NoReturn:
+ self._immutable()
- clear = pop = popitem = setdefault = update = ImmutableContainer._immutable
+class ImmutableDictBase(ImmutableContainer, Dict[_KT, _VT]):
+ def clear(self) -> NoReturn:
+ self._immutable()
+
+ def pop(self, key: Any, default: Optional[Any] = None) -> NoReturn:
+ self._immutable()
+
+ def popitem(self) -> NoReturn:
+ self._immutable()
+
+ def setdefault(self, key: Any, default: Optional[Any] = None) -> NoReturn:
+ self._immutable()
+
+ def update(self, *arg: Any, **kw: Any) -> NoReturn:
+ self._immutable()
+
+
+class immutabledict(ImmutableDictBase[_KT, _VT]):
def __new__(cls, *args):
new = dict.__new__(cls)
dict.__init__(new, *args)
@@ -41,7 +76,7 @@ class immutabledict(ImmutableContainer, dict):
dict.__init__(new, self)
if __d:
dict.update(new, __d)
- dict.update(new, kw)
+ dict.update(new, kw) # type: ignore
return new
def merge_with(self, *dicts):
@@ -61,110 +96,145 @@ class immutabledict(ImmutableContainer, dict):
return "immutabledict(%s)" % dict.__repr__(self)
-class OrderedSet(set):
+class OrderedSet(Generic[_T]):
+ __slots__ = ("_list", "_set", "__weakref__")
+
+ _list: List[_T]
+ _set: Set[_T]
+
def __init__(self, d=None):
- set.__init__(self)
if d is not None:
self._list = unique_list(d)
- set.update(self, self._list)
+ self._set = set(self._list)
else:
self._list = []
+ self._set = set()
+
+ def __reduce__(self):
+ return (OrderedSet, (self._list,))
- def add(self, element):
+ def add(self, element: _T) -> None:
if element not in self:
self._list.append(element)
- set.add(self, element)
+ self._set.add(element)
- def remove(self, element):
- set.remove(self, element)
+ def remove(self, element: _T) -> None:
+ self._set.remove(element)
self._list.remove(element)
- def insert(self, pos, element):
+ def insert(self, pos: int, element: _T) -> None:
if element not in self:
self._list.insert(pos, element)
- set.add(self, element)
+ self._set.add(element)
- def discard(self, element):
+ def discard(self, element: _T) -> None:
if element in self:
self._list.remove(element)
- set.remove(self, element)
+ self._set.remove(element)
- def clear(self):
- set.clear(self)
+ def clear(self) -> None:
+ self._set.clear()
self._list = []
- def __getitem__(self, key):
+ def __len__(self) -> int:
+ return len(self._set)
+
+ def __eq__(self, other):
+ if not isinstance(other, OrderedSet):
+ return self._set == other
+ else:
+ return self._set == other._set
+
+ def __ne__(self, other):
+ if not isinstance(other, OrderedSet):
+ return self._set != other
+ else:
+ return self._set != other._set
+
+ def __contains__(self, element: Any) -> bool:
+ return element in self._set
+
+ def __getitem__(self, key: int) -> _T:
return self._list[key]
- def __iter__(self):
+ def __iter__(self) -> Iterator[_T]:
return iter(self._list)
- def __add__(self, other):
+ def __add__(self, other: Iterator[_T]) -> "OrderedSet[_T]":
return self.union(other)
- def __repr__(self):
+ def __repr__(self) -> str:
return "%s(%r)" % (self.__class__.__name__, self._list)
__str__ = __repr__
- def update(self, iterable):
- for e in iterable:
- if e not in self:
- self._list.append(e)
- set.add(self, e)
- return self
+ def update(self, *iterables: Iterable[_T]) -> None:
+ for iterable in iterables:
+ for e in iterable:
+ if e not in self:
+ self._list.append(e)
+ self._set.add(e)
- __ior__ = update
+ def __ior__(self, other: Iterable[_T]) -> "OrderedSet[_T]":
+ self.update(other)
+ return self
- def union(self, other):
+ def union(self, other: Iterable[_T]) -> "OrderedSet[_T]":
result = self.__class__(self)
result.update(other)
return result
- __or__ = union
+ def __or__(self, other: Iterable[_T]) -> "OrderedSet[_T]":
+ return self.union(other)
- def intersection(self, other):
+ def intersection(self, other: Iterable[_T]) -> "OrderedSet[_T]":
other = other if isinstance(other, set) else set(other)
return self.__class__(a for a in self if a in other)
- __and__ = intersection
+ def __and__(self, other: Iterable[_T]) -> "OrderedSet[_T]":
+ return self.intersection(other)
- def symmetric_difference(self, other):
+ def symmetric_difference(self, other: Iterable[_T]) -> "OrderedSet[_T]":
other_set = other if isinstance(other, set) else set(other)
result = self.__class__(a for a in self if a not in other_set)
result.update(a for a in other if a not in self)
return result
- __xor__ = symmetric_difference
+ def __xor__(self, other: Iterable[_T]) -> "OrderedSet[_T]":
+ return self.symmetric_difference(other)
- def difference(self, other):
+ def difference(self, other: Iterable[_T]) -> "OrderedSet[_T]":
other = other if isinstance(other, set) else set(other)
return self.__class__(a for a in self if a not in other)
- __sub__ = difference
+ def __sub__(self, other: Iterable[_T]) -> "OrderedSet[_T]":
+ return self.difference(other)
- def intersection_update(self, other):
+ def intersection_update(self, other: Iterable[_T]) -> None:
other = other if isinstance(other, set) else set(other)
- set.intersection_update(self, other)
+ self._set.intersection_update(other)
self._list = [a for a in self._list if a in other]
- return self
- __iand__ = intersection_update
+ def __iand__(self, other: Iterable[_T]) -> "OrderedSet[_T]":
+ self.intersection_update(other)
+ return self
- def symmetric_difference_update(self, other):
- set.symmetric_difference_update(self, other)
+ def symmetric_difference_update(self, other: Iterable[_T]) -> None:
+ self._set.symmetric_difference_update(other)
self._list = [a for a in self._list if a in self]
self._list += [a for a in other if a in self]
- return self
- __ixor__ = symmetric_difference_update
+ def __ixor__(self, other: Iterable[_T]) -> "OrderedSet[_T]":
+ self.symmetric_difference_update(other)
+ return self
- def difference_update(self, other):
- set.difference_update(self, other)
+ def difference_update(self, other: Iterable[_T]) -> None:
+ self._set.difference_update(other)
self._list = [a for a in self._list if a in self]
- return self
- __isub__ = difference_update
+ def __isub__(self, other: Iterable[_T]) -> "OrderedSet[_T]":
+ self.difference_update(other)
+ return self
class IdentitySet: