diff options
Diffstat (limited to 'lib/sqlalchemy/dialects/postgresql/ranges.py')
-rw-r--r-- | lib/sqlalchemy/dialects/postgresql/ranges.py | 206 |
1 files changed, 206 insertions, 0 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/ranges.py b/lib/sqlalchemy/dialects/postgresql/ranges.py index 327feb409..6729f3785 100644 --- a/lib/sqlalchemy/dialects/postgresql/ranges.py +++ b/lib/sqlalchemy/dialects/postgresql/ranges.py @@ -8,10 +8,14 @@ from __future__ import annotations import dataclasses +from datetime import date +from datetime import datetime +from datetime import timedelta from typing import Any from typing import Generic from typing import Optional from typing import TypeVar +from typing import Union from ... import types as sqltypes from ...util import py310 @@ -80,6 +84,204 @@ class Range(Generic[_T]): def __bool__(self) -> bool: return self.empty + def _contains_value(self, value: _T) -> bool: + "Check whether this range contains the given `value`." + + if self.empty: + return False + + if self.lower is None: + return self.upper is None or ( + value < self.upper + if self.bounds[1] == ")" + else value <= self.upper + ) + + if self.upper is None: + return ( + value > self.lower + if self.bounds[0] == "(" + else value >= self.lower + ) + + return ( + value > self.lower + if self.bounds[0] == "(" + else value >= self.lower + ) and ( + value < self.upper + if self.bounds[1] == ")" + else value <= self.upper + ) + + def _get_discrete_step(self): + "Determine the “step” for this range, if it is a discrete one." + + # See + # https://www.postgresql.org/docs/current/rangetypes.html#RANGETYPES-DISCRETE + # for the rationale + + if isinstance(self.lower, int) or isinstance(self.upper, int): + return 1 + elif isinstance(self.lower, datetime) or isinstance( + self.upper, datetime + ): + # This is required, because a `isinstance(datetime.now(), date)` + # is True + return None + elif isinstance(self.lower, date) or isinstance(self.upper, date): + return timedelta(days=1) + else: + return None + + def contained_by(self, other: Range) -> bool: + "Determine whether this range is a contained by `other`." + + # Any range contains the empty one + if self.empty: + return True + + # An empty range does not contain any range except the empty one + if other.empty: + return False + + olower = other.lower + oupper = other.upper + + # A bilateral unbound range contains any other range + if olower is oupper is None: + return True + + slower = self.lower + supper = self.upper + + # A lower-bound range cannot contain a lower-unbound range + if slower is None and olower is not None: + return False + + # Likewise on the right side + if supper is None and oupper is not None: + return False + + slower_inc = self.bounds[0] == "[" + supper_inc = self.bounds[1] == "]" + olower_inc = other.bounds[0] == "[" + oupper_inc = other.bounds[1] == "]" + + # Check the lower end + step = -1 + if slower is not None and olower is not None: + lside = olower < slower + if not lside: + if not slower_inc or olower_inc: + lside = olower == slower + if not lside: + # Cover (1,x] vs [2,x) and (0,x] vs [1,x) + if not slower_inc and olower_inc and slower < olower: + step = self._get_discrete_step() + if step is not None: + lside = olower == (slower + step) + elif slower_inc and not olower_inc and slower > olower: + step = self._get_discrete_step() + if step is not None: + lside = (olower + step) == slower + if not lside: + return False + + # Lower end already considered, an upper-unbound range surely contains + # this + if oupper is None: + return True + + # Check the upper end + uside = oupper > supper + if not uside: + if not supper_inc or oupper_inc: + uside = oupper == supper + if not uside: + # Cover (x,2] vs [x,3) and (x,1] vs [x,2) + if supper_inc and not oupper_inc and supper < oupper: + if step == -1: + step = self._get_discrete_step() + if step is not None: + uside = oupper == (supper + step) + elif not supper_inc and oupper_inc and supper > oupper: + if step == -1: + step = self._get_discrete_step() + if step is not None: + uside = (oupper + step) == supper + return uside + + def contains(self, value: Union[_T, Range]) -> bool: + "Determine whether this range contains `value`." + + if isinstance(value, Range): + return value.contained_by(self) + else: + return self._contains_value(value) + + def overlaps(self, other): + """Boolean expression. Returns true if the column overlaps + (has points in common with) the right hand operand. + """ + raise NotImplementedError("not yet implemented") + + def strictly_left_of(self, other): + """Boolean expression. Returns true if the column is strictly + left of the right hand operand. + """ + raise NotImplementedError("not yet implemented") + + __lshift__ = strictly_left_of + + def strictly_right_of(self, other): + """Boolean expression. Returns true if the column is strictly + right of the right hand operand. + """ + raise NotImplementedError("not yet implemented") + + __rshift__ = strictly_right_of + + def not_extend_right_of(self, other): + """Boolean expression. Returns true if the range in the column + does not extend right of the range in the operand. + """ + raise NotImplementedError("not yet implemented") + + def not_extend_left_of(self, other): + """Boolean expression. Returns true if the range in the column + does not extend left of the range in the operand. + """ + raise NotImplementedError("not yet implemented") + + def adjacent_to(self, other): + """Boolean expression. Returns true if the range in the column + is adjacent to the range in the operand. + """ + raise NotImplementedError("not yet implemented") + + def __add__(self, other): + """Range expression. Returns the union of the two ranges. + Will raise an exception if the resulting range is not + contiguous. + """ + raise NotImplementedError("not yet implemented") + + def __str__(self): + return self._stringify() + + def _stringify(self): + if self.empty: + return "empty" + + l, r = self.lower, self.upper + l = "" if l is None else l + r = "" if r is None else r + + b0, b1 = self.bounds + + return f"{b0}{l},{r}{b1}" + class AbstractRange(sqltypes.TypeEngine): """ @@ -93,6 +295,8 @@ class AbstractRange(sqltypes.TypeEngine): render_bind_cast = True + __abstract__ = True + def adapt(self, impltype): """dynamically adapt a range type to an abstract impl. @@ -202,6 +406,8 @@ class AbstractRangeImpl(AbstractRange): class AbstractMultiRange(AbstractRange): """base for PostgreSQL MULTIRANGE types""" + __abstract__ = True + class AbstractMultiRangeImpl(AbstractRangeImpl, AbstractMultiRange): """marker for AbstractRange that will apply a subclass-specific |