summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--doc/build/changelog/unreleased_20/8706.rst11
-rw-r--r--doc/build/changelog/whatsnew_20.rst9
-rw-r--r--doc/build/dialects/postgresql.rst7
-rw-r--r--lib/sqlalchemy/dialects/postgresql/__init__.py2
-rw-r--r--lib/sqlalchemy/dialects/postgresql/ranges.py206
-rw-r--r--test/dialect/postgresql/test_types.py270
6 files changed, 483 insertions, 22 deletions
diff --git a/doc/build/changelog/unreleased_20/8706.rst b/doc/build/changelog/unreleased_20/8706.rst
new file mode 100644
index 000000000..a6f3321b6
--- /dev/null
+++ b/doc/build/changelog/unreleased_20/8706.rst
@@ -0,0 +1,11 @@
+.. change::
+ :tags: feature, postgresql
+ :tickets: 8706
+
+ Added new methods :meth:`_postgresql.Range.contains` and
+ :meth:`_postgresql.Range.contained_by` to the new :class:`.Range` data
+ object, which mirror the behavior of the PostgreSQL ``@>`` and ``<@``
+ operators, as well as the
+ :meth:`_postgresql.AbstractRange.comparator_factory.contains` and
+ :meth:`_postgresql.AbstractRange.comparator_factory.contained_by` SQL
+ operator methods. Pull request courtesy Lele Gaifax.
diff --git a/doc/build/changelog/whatsnew_20.rst b/doc/build/changelog/whatsnew_20.rst
index 98865f233..3d4eca6b2 100644
--- a/doc/build/changelog/whatsnew_20.rst
+++ b/doc/build/changelog/whatsnew_20.rst
@@ -1883,6 +1883,12 @@ objects are used.
Code that used the previous psycopg2-specific types should be modified
to use :class:`_postgresql.Range`, which presents a compatible interface.
+The :class:`_postgresql.Range` object also features comparison support which
+mirrors that of PostgreSQL. Implemented so far are :meth:`_postgresql.Range.contains`
+and :meth:`_postgresql.Range.contained_by` methods which work in the same way as
+the PostgreSQL ``@>`` and ``<@``. Additional operator support may be added
+in future releases.
+
See the documentation at :ref:`postgresql_ranges` for background on
using the new feature.
@@ -1891,6 +1897,9 @@ using the new feature.
:ref:`postgresql_ranges`
+:ticket:`7156`
+:ticket:`8706`
+
.. _change_7086:
``match()`` operator on PostgreSQL uses ``plainto_tsquery()`` rather than ``to_tsquery()``
diff --git a/doc/build/dialects/postgresql.rst b/doc/build/dialects/postgresql.rst
index 411037f87..ef853b683 100644
--- a/doc/build/dialects/postgresql.rst
+++ b/doc/build/dialects/postgresql.rst
@@ -216,6 +216,7 @@ The available range datatypes are as follows:
* :class:`_postgresql.TSTZRANGE`
.. autoclass:: sqlalchemy.dialects.postgresql.Range
+ :members:
Multiranges
^^^^^^^^^^^
@@ -350,6 +351,12 @@ construction arguments, are as follows:
.. currentmodule:: sqlalchemy.dialects.postgresql
+.. autoclass:: sqlalchemy.dialects.postgresql.AbstractRange
+ :members: comparator_factory
+
+.. autoclass:: sqlalchemy.dialects.postgresql.AbstractMultiRange
+ :members: comparator_factory
+
.. autoclass:: aggregate_order_by
.. autoclass:: array
diff --git a/lib/sqlalchemy/dialects/postgresql/__init__.py b/lib/sqlalchemy/dialects/postgresql/__init__.py
index 8dbee1f7f..7890541ff 100644
--- a/lib/sqlalchemy/dialects/postgresql/__init__.py
+++ b/lib/sqlalchemy/dialects/postgresql/__init__.py
@@ -48,6 +48,8 @@ from .named_types import DropDomainType
from .named_types import DropEnumType
from .named_types import ENUM
from .named_types import NamedType
+from .ranges import AbstractMultiRange
+from .ranges import AbstractRange
from .ranges import DATEMULTIRANGE
from .ranges import DATERANGE
from .ranges import INT4MULTIRANGE
diff --git a/lib/sqlalchemy/dialects/postgresql/ranges.py b/lib/sqlalchemy/dialects/postgresql/ranges.py
index 327feb409..6729f3785 100644
--- a/lib/sqlalchemy/dialects/postgresql/ranges.py
+++ b/lib/sqlalchemy/dialects/postgresql/ranges.py
@@ -8,10 +8,14 @@
from __future__ import annotations
import dataclasses
+from datetime import date
+from datetime import datetime
+from datetime import timedelta
from typing import Any
from typing import Generic
from typing import Optional
from typing import TypeVar
+from typing import Union
from ... import types as sqltypes
from ...util import py310
@@ -80,6 +84,204 @@ class Range(Generic[_T]):
def __bool__(self) -> bool:
return self.empty
+ def _contains_value(self, value: _T) -> bool:
+ "Check whether this range contains the given `value`."
+
+ if self.empty:
+ return False
+
+ if self.lower is None:
+ return self.upper is None or (
+ value < self.upper
+ if self.bounds[1] == ")"
+ else value <= self.upper
+ )
+
+ if self.upper is None:
+ return (
+ value > self.lower
+ if self.bounds[0] == "("
+ else value >= self.lower
+ )
+
+ return (
+ value > self.lower
+ if self.bounds[0] == "("
+ else value >= self.lower
+ ) and (
+ value < self.upper
+ if self.bounds[1] == ")"
+ else value <= self.upper
+ )
+
+ def _get_discrete_step(self):
+ "Determine the “step” for this range, if it is a discrete one."
+
+ # See
+ # https://www.postgresql.org/docs/current/rangetypes.html#RANGETYPES-DISCRETE
+ # for the rationale
+
+ if isinstance(self.lower, int) or isinstance(self.upper, int):
+ return 1
+ elif isinstance(self.lower, datetime) or isinstance(
+ self.upper, datetime
+ ):
+ # This is required, because a `isinstance(datetime.now(), date)`
+ # is True
+ return None
+ elif isinstance(self.lower, date) or isinstance(self.upper, date):
+ return timedelta(days=1)
+ else:
+ return None
+
+ def contained_by(self, other: Range) -> bool:
+ "Determine whether this range is a contained by `other`."
+
+ # Any range contains the empty one
+ if self.empty:
+ return True
+
+ # An empty range does not contain any range except the empty one
+ if other.empty:
+ return False
+
+ olower = other.lower
+ oupper = other.upper
+
+ # A bilateral unbound range contains any other range
+ if olower is oupper is None:
+ return True
+
+ slower = self.lower
+ supper = self.upper
+
+ # A lower-bound range cannot contain a lower-unbound range
+ if slower is None and olower is not None:
+ return False
+
+ # Likewise on the right side
+ if supper is None and oupper is not None:
+ return False
+
+ slower_inc = self.bounds[0] == "["
+ supper_inc = self.bounds[1] == "]"
+ olower_inc = other.bounds[0] == "["
+ oupper_inc = other.bounds[1] == "]"
+
+ # Check the lower end
+ step = -1
+ if slower is not None and olower is not None:
+ lside = olower < slower
+ if not lside:
+ if not slower_inc or olower_inc:
+ lside = olower == slower
+ if not lside:
+ # Cover (1,x] vs [2,x) and (0,x] vs [1,x)
+ if not slower_inc and olower_inc and slower < olower:
+ step = self._get_discrete_step()
+ if step is not None:
+ lside = olower == (slower + step)
+ elif slower_inc and not olower_inc and slower > olower:
+ step = self._get_discrete_step()
+ if step is not None:
+ lside = (olower + step) == slower
+ if not lside:
+ return False
+
+ # Lower end already considered, an upper-unbound range surely contains
+ # this
+ if oupper is None:
+ return True
+
+ # Check the upper end
+ uside = oupper > supper
+ if not uside:
+ if not supper_inc or oupper_inc:
+ uside = oupper == supper
+ if not uside:
+ # Cover (x,2] vs [x,3) and (x,1] vs [x,2)
+ if supper_inc and not oupper_inc and supper < oupper:
+ if step == -1:
+ step = self._get_discrete_step()
+ if step is not None:
+ uside = oupper == (supper + step)
+ elif not supper_inc and oupper_inc and supper > oupper:
+ if step == -1:
+ step = self._get_discrete_step()
+ if step is not None:
+ uside = (oupper + step) == supper
+ return uside
+
+ def contains(self, value: Union[_T, Range]) -> bool:
+ "Determine whether this range contains `value`."
+
+ if isinstance(value, Range):
+ return value.contained_by(self)
+ else:
+ return self._contains_value(value)
+
+ def overlaps(self, other):
+ """Boolean expression. Returns true if the column overlaps
+ (has points in common with) the right hand operand.
+ """
+ raise NotImplementedError("not yet implemented")
+
+ def strictly_left_of(self, other):
+ """Boolean expression. Returns true if the column is strictly
+ left of the right hand operand.
+ """
+ raise NotImplementedError("not yet implemented")
+
+ __lshift__ = strictly_left_of
+
+ def strictly_right_of(self, other):
+ """Boolean expression. Returns true if the column is strictly
+ right of the right hand operand.
+ """
+ raise NotImplementedError("not yet implemented")
+
+ __rshift__ = strictly_right_of
+
+ def not_extend_right_of(self, other):
+ """Boolean expression. Returns true if the range in the column
+ does not extend right of the range in the operand.
+ """
+ raise NotImplementedError("not yet implemented")
+
+ def not_extend_left_of(self, other):
+ """Boolean expression. Returns true if the range in the column
+ does not extend left of the range in the operand.
+ """
+ raise NotImplementedError("not yet implemented")
+
+ def adjacent_to(self, other):
+ """Boolean expression. Returns true if the range in the column
+ is adjacent to the range in the operand.
+ """
+ raise NotImplementedError("not yet implemented")
+
+ def __add__(self, other):
+ """Range expression. Returns the union of the two ranges.
+ Will raise an exception if the resulting range is not
+ contiguous.
+ """
+ raise NotImplementedError("not yet implemented")
+
+ def __str__(self):
+ return self._stringify()
+
+ def _stringify(self):
+ if self.empty:
+ return "empty"
+
+ l, r = self.lower, self.upper
+ l = "" if l is None else l
+ r = "" if r is None else r
+
+ b0, b1 = self.bounds
+
+ return f"{b0}{l},{r}{b1}"
+
class AbstractRange(sqltypes.TypeEngine):
"""
@@ -93,6 +295,8 @@ class AbstractRange(sqltypes.TypeEngine):
render_bind_cast = True
+ __abstract__ = True
+
def adapt(self, impltype):
"""dynamically adapt a range type to an abstract impl.
@@ -202,6 +406,8 @@ class AbstractRangeImpl(AbstractRange):
class AbstractMultiRange(AbstractRange):
"""base for PostgreSQL MULTIRANGE types"""
+ __abstract__ = True
+
class AbstractMultiRangeImpl(AbstractRangeImpl, AbstractMultiRange):
"""marker for AbstractRange that will apply a subclass-specific
diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py
index 91eada9a8..83cea8f15 100644
--- a/test/dialect/postgresql/test_types.py
+++ b/test/dialect/postgresql/test_types.py
@@ -20,6 +20,7 @@ from sqlalchemy import func
from sqlalchemy import inspect
from sqlalchemy import Integer
from sqlalchemy import literal
+from sqlalchemy import literal_column
from sqlalchemy import MetaData
from sqlalchemy import null
from sqlalchemy import Numeric
@@ -66,6 +67,8 @@ from sqlalchemy.sql import operators
from sqlalchemy.sql import sqltypes
from sqlalchemy.testing import expect_raises_message
from sqlalchemy.testing import fixtures
+from sqlalchemy.testing import is_false
+from sqlalchemy.testing import is_true
from sqlalchemy.testing.assertions import assert_raises
from sqlalchemy.testing.assertions import assert_raises_message
from sqlalchemy.testing.assertions import AssertsCompiledSQL
@@ -3846,7 +3849,194 @@ class _RangeTypeCompilation(AssertsCompiledSQL, fixtures.TestBase):
)
-class _RangeTypeRoundTrip(fixtures.TablesTest):
+class _RangeComparisonFixtures:
+ def _data_str(self):
+ """return string form of a sample range"""
+ raise NotImplementedError()
+
+ def _data_obj(self):
+ """return Range form of the same range"""
+ raise NotImplementedError()
+
+ def _step_value_up(self, value):
+ """given a value, return a step up
+
+ this is a value that given the lower end of the sample range,
+ would be less than the upper value of the range
+
+ """
+ raise NotImplementedError()
+
+ def _step_value_down(self, value):
+ """given a value, return a step down
+
+ this is a value that given the upper end of the sample range,
+ would be greater than the lower value of the range
+
+ """
+ raise NotImplementedError()
+
+ def _value_values(self):
+ """Return a series of values related to the base range
+
+ le = left equal
+ ll = lower than left
+ re = right equal
+ rh = higher than right
+ il = inside lower
+ ih = inside higher
+
+ """
+ spec = self._data_obj()
+
+ le, re_ = spec.lower, spec.upper
+
+ ll = self._step_value_down(le)
+ il = self._step_value_up(le)
+ rh = self._step_value_up(re_)
+ ih = self._step_value_down(re_)
+
+ return {"le": le, "re_": re_, "ll": ll, "il": il, "rh": rh, "ih": ih}
+
+ @testing.fixture(
+ params=[
+ lambda **kw: Range(empty=True),
+ lambda **kw: Range(bounds="[)"),
+ lambda le, **kw: Range(upper=le, bounds="[)"),
+ lambda le, re_, **kw: Range(lower=le, upper=re_, bounds="[)"),
+ lambda le, re_, **kw: Range(lower=le, upper=re_, bounds="[)"),
+ lambda le, re_, **kw: Range(lower=le, upper=re_, bounds="[]"),
+ lambda le, re_, **kw: Range(lower=le, upper=re_, bounds="(]"),
+ lambda le, re_, **kw: Range(lower=le, upper=re_, bounds="()"),
+ lambda ll, le, **kw: Range(lower=ll, upper=le, bounds="[)"),
+ lambda il, ih, **kw: Range(lower=il, upper=ih, bounds="[)"),
+ lambda ll, le, **kw: Range(lower=ll, upper=le, bounds="(]"),
+ lambda ll, rh, **kw: Range(lower=ll, upper=rh, bounds="[)"),
+ ]
+ )
+ def contains_range_obj_combinations(self, request):
+ """ranges that are used for range contains() contained_by() tests"""
+ data = self._value_values()
+
+ range_ = request.param(**data)
+ yield range_
+
+ @testing.fixture(
+ params=[
+ lambda l, r: Range(empty=True),
+ lambda l, r: Range(bounds="()"),
+ lambda l, r: Range(upper=r, bounds="(]"),
+ lambda l, r: Range(lower=l, bounds="[)"),
+ lambda l, r: Range(lower=l, upper=r, bounds="[)"),
+ lambda l, r: Range(lower=l, upper=r, bounds="[]"),
+ lambda l, r: Range(lower=l, upper=r, bounds="(]"),
+ lambda l, r: Range(lower=l, upper=r, bounds="()"),
+ ]
+ )
+ def bounds_obj_combinations(self, request):
+ """sample ranges used for value and range contains()/contained_by()
+ tests"""
+
+ obj = self._data_obj()
+ l, r = obj.lower, obj.upper
+
+ template = request.param
+ value = template(l=l, r=r)
+ yield value
+
+ @testing.fixture(params=["ll", "le", "il", "ih", "re_", "rh"])
+ def value_combinations(self, request):
+ """sample values used for value contains() tests"""
+ data = self._value_values()
+ return data[request.param]
+
+ def test_basic_py_sanity(self):
+ values = self._value_values()
+
+ range_ = self._data_obj()
+
+ is_true(range_.contains(Range(lower=values["il"], upper=values["ih"])))
+
+ is_true(
+ range_.contained_by(Range(lower=values["ll"], upper=values["rh"]))
+ )
+
+ is_true(range_.contains(values["il"]))
+
+ is_false(
+ range_.contains(Range(lower=values["ll"], upper=values["ih"]))
+ )
+
+ is_false(range_.contains(values["rh"]))
+
+ def test_contains_value(
+ self, connection, bounds_obj_combinations, value_combinations
+ ):
+ range_ = bounds_obj_combinations
+ range_typ = self._col_str
+
+ strvalue = range_._stringify()
+
+ v = value_combinations
+ RANGE = self._col_type
+
+ q = select(
+ literal_column(f"'{strvalue}'::{range_typ}", RANGE).label("r1"),
+ cast(range_, RANGE).label("r2"),
+ )
+ literal_range, cast_range = connection.execute(q).first()
+ eq_(literal_range, cast_range)
+
+ q = select(
+ cast(range_, RANGE),
+ cast(range_, RANGE).contains(v),
+ )
+ r, expected = connection.execute(q).first()
+ eq_(r.contains(v), expected)
+
+ def test_contains_range(
+ self,
+ connection,
+ bounds_obj_combinations,
+ contains_range_obj_combinations,
+ ):
+ r1repr = contains_range_obj_combinations._stringify()
+ r2repr = bounds_obj_combinations._stringify()
+
+ RANGE = self._col_type
+ range_typ = self._col_str
+
+ q = select(
+ cast(contains_range_obj_combinations, RANGE).label("r1"),
+ cast(bounds_obj_combinations, RANGE).label("r2"),
+ cast(contains_range_obj_combinations, RANGE).contains(
+ bounds_obj_combinations
+ ),
+ cast(contains_range_obj_combinations, RANGE).contained_by(
+ bounds_obj_combinations
+ ),
+ )
+ validate_q = select(
+ literal_column(f"'{r1repr}'::{range_typ}", RANGE).label("r1"),
+ literal_column(f"'{r2repr}'::{range_typ}", RANGE).label("r2"),
+ literal_column(
+ f"'{r1repr}'::{range_typ} @> '{r2repr}'::{range_typ}"
+ ),
+ literal_column(
+ f"'{r1repr}'::{range_typ} <@ '{r2repr}'::{range_typ}"
+ ),
+ )
+ orig_row = connection.execute(q).first()
+ validate_row = connection.execute(validate_q).first()
+ eq_(orig_row, validate_row)
+
+ r1, r2, contains, contained = orig_row
+ eq_(r1.contains(r2), contains)
+ eq_(r1.contained_by(r2), contained)
+ eq_(r2.contains(r1), contained)
+
+
+class _RangeTypeRoundTrip(_RangeComparisonFixtures, fixtures.TablesTest):
__requires__ = ("range_types",)
__backend__ = True
@@ -3861,6 +4051,9 @@ class _RangeTypeRoundTrip(fixtures.TablesTest):
)
cls.col = table.c.range
+ def test_stringify(self):
+ eq_(str(self._data_obj()), self._data_str())
+
def test_auto_cast_back_to_type(self, connection):
"""test that a straight pass of the range type without any context
will send appropriate casting info so that the driver can round
@@ -3995,10 +4188,16 @@ class _Int4RangeTests:
_col_str = "INT4RANGE"
def _data_str(self):
- return "[1,2)"
+ return "[1,4)"
def _data_obj(self):
- return Range(1, 2)
+ return Range(1, 4)
+
+ def _step_value_up(self, value):
+ return value + 1
+
+ def _step_value_down(self, value):
+ return value - 1
class _Int8RangeTests:
@@ -4007,10 +4206,16 @@ class _Int8RangeTests:
_col_str = "INT8RANGE"
def _data_str(self):
- return "[9223372036854775806,9223372036854775807)"
+ return "[9223372036854775306,9223372036854775800)"
def _data_obj(self):
- return Range(9223372036854775806, 9223372036854775807)
+ return Range(9223372036854775306, 9223372036854775800)
+
+ def _step_value_up(self, value):
+ return value + 5
+
+ def _step_value_down(self, value):
+ return value - 5
class _NumRangeTests:
@@ -4019,10 +4224,16 @@ class _NumRangeTests:
_col_str = "NUMRANGE"
def _data_str(self):
- return "[1.0,2.0)"
+ return "[1.0,9.0)"
def _data_obj(self):
- return Range(decimal.Decimal("1.0"), decimal.Decimal("2.0"))
+ return Range(decimal.Decimal("1.0"), decimal.Decimal("9.0"))
+
+ def _step_value_up(self, value):
+ return value + decimal.Decimal("1.8")
+
+ def _step_value_down(self, value):
+ return value - decimal.Decimal("1.8")
class _DateRangeTests:
@@ -4031,10 +4242,16 @@ class _DateRangeTests:
_col_str = "DATERANGE"
def _data_str(self):
- return "[2013-03-23,2013-03-24)"
+ return "[2013-03-23,2013-03-30)"
def _data_obj(self):
- return Range(datetime.date(2013, 3, 23), datetime.date(2013, 3, 24))
+ return Range(datetime.date(2013, 3, 23), datetime.date(2013, 3, 30))
+
+ def _step_value_up(self, value):
+ return value + datetime.timedelta(days=1)
+
+ def _step_value_down(self, value):
+ return value - datetime.timedelta(days=1)
class _DateTimeRangeTests:
@@ -4043,38 +4260,47 @@ class _DateTimeRangeTests:
_col_str = "TSRANGE"
def _data_str(self):
- return "[2013-03-23 14:30,2013-03-23 23:30)"
+ return "[2013-03-23 14:30:00,2013-03-30 23:30:00)"
def _data_obj(self):
return Range(
datetime.datetime(2013, 3, 23, 14, 30),
- datetime.datetime(2013, 3, 23, 23, 30),
+ datetime.datetime(2013, 3, 30, 23, 30),
)
+ def _step_value_up(self, value):
+ return value + datetime.timedelta(days=1)
+
+ def _step_value_down(self, value):
+ return value - datetime.timedelta(days=1)
+
class _DateTimeTZRangeTests:
_col_type = TSTZRANGE
_col_str = "TSTZRANGE"
- # make sure we use one, steady timestamp with timezone pair
- # for all parts of all these tests
- _tstzs = None
-
def tstzs(self):
- if self._tstzs is None:
- with testing.db.connect() as connection:
- lower = connection.scalar(func.current_timestamp().select())
- upper = lower + datetime.timedelta(1)
- self._tstzs = (lower, upper)
- return self._tstzs
+ tz = datetime.timezone(-datetime.timedelta(hours=5, minutes=30))
+
+ return (
+ datetime.datetime(2013, 3, 23, 14, 30, tzinfo=tz),
+ datetime.datetime(2013, 3, 30, 23, 30, tzinfo=tz),
+ )
def _data_str(self):
- return "[%s,%s)" % self.tstzs()
+ l, r = self.tstzs()
+ return f"[{l},{r})"
def _data_obj(self):
return Range(*self.tstzs())
+ def _step_value_up(self, value):
+ return value + datetime.timedelta(days=1)
+
+ def _step_value_down(self, value):
+ return value - datetime.timedelta(days=1)
+
class Int4RangeCompilationTest(_Int4RangeTests, _RangeTypeCompilation):
pass