summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/dialects/postgresql/ranges.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/dialects/postgresql/ranges.py')
-rw-r--r--lib/sqlalchemy/dialects/postgresql/ranges.py206
1 files changed, 206 insertions, 0 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/ranges.py b/lib/sqlalchemy/dialects/postgresql/ranges.py
index 327feb409..6729f3785 100644
--- a/lib/sqlalchemy/dialects/postgresql/ranges.py
+++ b/lib/sqlalchemy/dialects/postgresql/ranges.py
@@ -8,10 +8,14 @@
from __future__ import annotations
import dataclasses
+from datetime import date
+from datetime import datetime
+from datetime import timedelta
from typing import Any
from typing import Generic
from typing import Optional
from typing import TypeVar
+from typing import Union
from ... import types as sqltypes
from ...util import py310
@@ -80,6 +84,204 @@ class Range(Generic[_T]):
def __bool__(self) -> bool:
return self.empty
+ def _contains_value(self, value: _T) -> bool:
+ "Check whether this range contains the given `value`."
+
+ if self.empty:
+ return False
+
+ if self.lower is None:
+ return self.upper is None or (
+ value < self.upper
+ if self.bounds[1] == ")"
+ else value <= self.upper
+ )
+
+ if self.upper is None:
+ return (
+ value > self.lower
+ if self.bounds[0] == "("
+ else value >= self.lower
+ )
+
+ return (
+ value > self.lower
+ if self.bounds[0] == "("
+ else value >= self.lower
+ ) and (
+ value < self.upper
+ if self.bounds[1] == ")"
+ else value <= self.upper
+ )
+
+ def _get_discrete_step(self):
+ "Determine the “step” for this range, if it is a discrete one."
+
+ # See
+ # https://www.postgresql.org/docs/current/rangetypes.html#RANGETYPES-DISCRETE
+ # for the rationale
+
+ if isinstance(self.lower, int) or isinstance(self.upper, int):
+ return 1
+ elif isinstance(self.lower, datetime) or isinstance(
+ self.upper, datetime
+ ):
+ # This is required, because a `isinstance(datetime.now(), date)`
+ # is True
+ return None
+ elif isinstance(self.lower, date) or isinstance(self.upper, date):
+ return timedelta(days=1)
+ else:
+ return None
+
+ def contained_by(self, other: Range) -> bool:
+ "Determine whether this range is a contained by `other`."
+
+ # Any range contains the empty one
+ if self.empty:
+ return True
+
+ # An empty range does not contain any range except the empty one
+ if other.empty:
+ return False
+
+ olower = other.lower
+ oupper = other.upper
+
+ # A bilateral unbound range contains any other range
+ if olower is oupper is None:
+ return True
+
+ slower = self.lower
+ supper = self.upper
+
+ # A lower-bound range cannot contain a lower-unbound range
+ if slower is None and olower is not None:
+ return False
+
+ # Likewise on the right side
+ if supper is None and oupper is not None:
+ return False
+
+ slower_inc = self.bounds[0] == "["
+ supper_inc = self.bounds[1] == "]"
+ olower_inc = other.bounds[0] == "["
+ oupper_inc = other.bounds[1] == "]"
+
+ # Check the lower end
+ step = -1
+ if slower is not None and olower is not None:
+ lside = olower < slower
+ if not lside:
+ if not slower_inc or olower_inc:
+ lside = olower == slower
+ if not lside:
+ # Cover (1,x] vs [2,x) and (0,x] vs [1,x)
+ if not slower_inc and olower_inc and slower < olower:
+ step = self._get_discrete_step()
+ if step is not None:
+ lside = olower == (slower + step)
+ elif slower_inc and not olower_inc and slower > olower:
+ step = self._get_discrete_step()
+ if step is not None:
+ lside = (olower + step) == slower
+ if not lside:
+ return False
+
+ # Lower end already considered, an upper-unbound range surely contains
+ # this
+ if oupper is None:
+ return True
+
+ # Check the upper end
+ uside = oupper > supper
+ if not uside:
+ if not supper_inc or oupper_inc:
+ uside = oupper == supper
+ if not uside:
+ # Cover (x,2] vs [x,3) and (x,1] vs [x,2)
+ if supper_inc and not oupper_inc and supper < oupper:
+ if step == -1:
+ step = self._get_discrete_step()
+ if step is not None:
+ uside = oupper == (supper + step)
+ elif not supper_inc and oupper_inc and supper > oupper:
+ if step == -1:
+ step = self._get_discrete_step()
+ if step is not None:
+ uside = (oupper + step) == supper
+ return uside
+
+ def contains(self, value: Union[_T, Range]) -> bool:
+ "Determine whether this range contains `value`."
+
+ if isinstance(value, Range):
+ return value.contained_by(self)
+ else:
+ return self._contains_value(value)
+
+ def overlaps(self, other):
+ """Boolean expression. Returns true if the column overlaps
+ (has points in common with) the right hand operand.
+ """
+ raise NotImplementedError("not yet implemented")
+
+ def strictly_left_of(self, other):
+ """Boolean expression. Returns true if the column is strictly
+ left of the right hand operand.
+ """
+ raise NotImplementedError("not yet implemented")
+
+ __lshift__ = strictly_left_of
+
+ def strictly_right_of(self, other):
+ """Boolean expression. Returns true if the column is strictly
+ right of the right hand operand.
+ """
+ raise NotImplementedError("not yet implemented")
+
+ __rshift__ = strictly_right_of
+
+ def not_extend_right_of(self, other):
+ """Boolean expression. Returns true if the range in the column
+ does not extend right of the range in the operand.
+ """
+ raise NotImplementedError("not yet implemented")
+
+ def not_extend_left_of(self, other):
+ """Boolean expression. Returns true if the range in the column
+ does not extend left of the range in the operand.
+ """
+ raise NotImplementedError("not yet implemented")
+
+ def adjacent_to(self, other):
+ """Boolean expression. Returns true if the range in the column
+ is adjacent to the range in the operand.
+ """
+ raise NotImplementedError("not yet implemented")
+
+ def __add__(self, other):
+ """Range expression. Returns the union of the two ranges.
+ Will raise an exception if the resulting range is not
+ contiguous.
+ """
+ raise NotImplementedError("not yet implemented")
+
+ def __str__(self):
+ return self._stringify()
+
+ def _stringify(self):
+ if self.empty:
+ return "empty"
+
+ l, r = self.lower, self.upper
+ l = "" if l is None else l
+ r = "" if r is None else r
+
+ b0, b1 = self.bounds
+
+ return f"{b0}{l},{r}{b1}"
+
class AbstractRange(sqltypes.TypeEngine):
"""
@@ -93,6 +295,8 @@ class AbstractRange(sqltypes.TypeEngine):
render_bind_cast = True
+ __abstract__ = True
+
def adapt(self, impltype):
"""dynamically adapt a range type to an abstract impl.
@@ -202,6 +406,8 @@ class AbstractRangeImpl(AbstractRange):
class AbstractMultiRange(AbstractRange):
"""base for PostgreSQL MULTIRANGE types"""
+ __abstract__ = True
+
class AbstractMultiRangeImpl(AbstractRangeImpl, AbstractMultiRange):
"""marker for AbstractRange that will apply a subclass-specific