summaryrefslogtreecommitdiff
path: root/test/dialect/postgresql
diff options
context:
space:
mode:
authorLele Gaifax <lele@metapensiero.it>2022-11-02 08:33:41 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2022-11-03 09:30:38 -0400
commite8124b29b07fd17ab2f2b6892534dcc4b0797ab4 (patch)
tree5e8436a3f457b6bc2b7fb53c659afd92f209bc47 /test/dialect/postgresql
parent66f3533de86506327c753c1ea80b121692535745 (diff)
downloadsqlalchemy-e8124b29b07fd17ab2f2b6892534dcc4b0797ab4.tar.gz
Implement contains_value(), issubset() and issuperset() on PG Range
Added new methods :meth:`_postgresql.Range.contains` and :meth:`_postgresql.Range.contained_by` to the new :class:`.Range` data object, which mirror the behavior of the PostgreSQL ``@>`` and ``<@`` operators, as well as the :meth:`_postgresql.AbstractRange.comparator_factory.contains` and :meth:`_postgresql.AbstractRange.comparator_factory.contained_by` SQL operator methods. Pull request courtesy Lele Gaifax. Fixes: #8706 Closes: #8707 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/8707 Pull-request-sha: 3a74a0d93e63032ebee02992977498c717a077ff Change-Id: Ief81ca5c31448640b26dfbc3defd4dde1d51e366
Diffstat (limited to 'test/dialect/postgresql')
-rw-r--r--test/dialect/postgresql/test_types.py270
1 files changed, 248 insertions, 22 deletions
diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py
index 91eada9a8..83cea8f15 100644
--- a/test/dialect/postgresql/test_types.py
+++ b/test/dialect/postgresql/test_types.py
@@ -20,6 +20,7 @@ from sqlalchemy import func
from sqlalchemy import inspect
from sqlalchemy import Integer
from sqlalchemy import literal
+from sqlalchemy import literal_column
from sqlalchemy import MetaData
from sqlalchemy import null
from sqlalchemy import Numeric
@@ -66,6 +67,8 @@ from sqlalchemy.sql import operators
from sqlalchemy.sql import sqltypes
from sqlalchemy.testing import expect_raises_message
from sqlalchemy.testing import fixtures
+from sqlalchemy.testing import is_false
+from sqlalchemy.testing import is_true
from sqlalchemy.testing.assertions import assert_raises
from sqlalchemy.testing.assertions import assert_raises_message
from sqlalchemy.testing.assertions import AssertsCompiledSQL
@@ -3846,7 +3849,194 @@ class _RangeTypeCompilation(AssertsCompiledSQL, fixtures.TestBase):
)
-class _RangeTypeRoundTrip(fixtures.TablesTest):
+class _RangeComparisonFixtures:
+ def _data_str(self):
+ """return string form of a sample range"""
+ raise NotImplementedError()
+
+ def _data_obj(self):
+ """return Range form of the same range"""
+ raise NotImplementedError()
+
+ def _step_value_up(self, value):
+ """given a value, return a step up
+
+ this is a value that given the lower end of the sample range,
+ would be less than the upper value of the range
+
+ """
+ raise NotImplementedError()
+
+ def _step_value_down(self, value):
+ """given a value, return a step down
+
+ this is a value that given the upper end of the sample range,
+ would be greater than the lower value of the range
+
+ """
+ raise NotImplementedError()
+
+ def _value_values(self):
+ """Return a series of values related to the base range
+
+ le = left equal
+ ll = lower than left
+ re = right equal
+ rh = higher than right
+ il = inside lower
+ ih = inside higher
+
+ """
+ spec = self._data_obj()
+
+ le, re_ = spec.lower, spec.upper
+
+ ll = self._step_value_down(le)
+ il = self._step_value_up(le)
+ rh = self._step_value_up(re_)
+ ih = self._step_value_down(re_)
+
+ return {"le": le, "re_": re_, "ll": ll, "il": il, "rh": rh, "ih": ih}
+
+ @testing.fixture(
+ params=[
+ lambda **kw: Range(empty=True),
+ lambda **kw: Range(bounds="[)"),
+ lambda le, **kw: Range(upper=le, bounds="[)"),
+ lambda le, re_, **kw: Range(lower=le, upper=re_, bounds="[)"),
+ lambda le, re_, **kw: Range(lower=le, upper=re_, bounds="[)"),
+ lambda le, re_, **kw: Range(lower=le, upper=re_, bounds="[]"),
+ lambda le, re_, **kw: Range(lower=le, upper=re_, bounds="(]"),
+ lambda le, re_, **kw: Range(lower=le, upper=re_, bounds="()"),
+ lambda ll, le, **kw: Range(lower=ll, upper=le, bounds="[)"),
+ lambda il, ih, **kw: Range(lower=il, upper=ih, bounds="[)"),
+ lambda ll, le, **kw: Range(lower=ll, upper=le, bounds="(]"),
+ lambda ll, rh, **kw: Range(lower=ll, upper=rh, bounds="[)"),
+ ]
+ )
+ def contains_range_obj_combinations(self, request):
+ """ranges that are used for range contains() contained_by() tests"""
+ data = self._value_values()
+
+ range_ = request.param(**data)
+ yield range_
+
+ @testing.fixture(
+ params=[
+ lambda l, r: Range(empty=True),
+ lambda l, r: Range(bounds="()"),
+ lambda l, r: Range(upper=r, bounds="(]"),
+ lambda l, r: Range(lower=l, bounds="[)"),
+ lambda l, r: Range(lower=l, upper=r, bounds="[)"),
+ lambda l, r: Range(lower=l, upper=r, bounds="[]"),
+ lambda l, r: Range(lower=l, upper=r, bounds="(]"),
+ lambda l, r: Range(lower=l, upper=r, bounds="()"),
+ ]
+ )
+ def bounds_obj_combinations(self, request):
+ """sample ranges used for value and range contains()/contained_by()
+ tests"""
+
+ obj = self._data_obj()
+ l, r = obj.lower, obj.upper
+
+ template = request.param
+ value = template(l=l, r=r)
+ yield value
+
+ @testing.fixture(params=["ll", "le", "il", "ih", "re_", "rh"])
+ def value_combinations(self, request):
+ """sample values used for value contains() tests"""
+ data = self._value_values()
+ return data[request.param]
+
+ def test_basic_py_sanity(self):
+ values = self._value_values()
+
+ range_ = self._data_obj()
+
+ is_true(range_.contains(Range(lower=values["il"], upper=values["ih"])))
+
+ is_true(
+ range_.contained_by(Range(lower=values["ll"], upper=values["rh"]))
+ )
+
+ is_true(range_.contains(values["il"]))
+
+ is_false(
+ range_.contains(Range(lower=values["ll"], upper=values["ih"]))
+ )
+
+ is_false(range_.contains(values["rh"]))
+
+ def test_contains_value(
+ self, connection, bounds_obj_combinations, value_combinations
+ ):
+ range_ = bounds_obj_combinations
+ range_typ = self._col_str
+
+ strvalue = range_._stringify()
+
+ v = value_combinations
+ RANGE = self._col_type
+
+ q = select(
+ literal_column(f"'{strvalue}'::{range_typ}", RANGE).label("r1"),
+ cast(range_, RANGE).label("r2"),
+ )
+ literal_range, cast_range = connection.execute(q).first()
+ eq_(literal_range, cast_range)
+
+ q = select(
+ cast(range_, RANGE),
+ cast(range_, RANGE).contains(v),
+ )
+ r, expected = connection.execute(q).first()
+ eq_(r.contains(v), expected)
+
+ def test_contains_range(
+ self,
+ connection,
+ bounds_obj_combinations,
+ contains_range_obj_combinations,
+ ):
+ r1repr = contains_range_obj_combinations._stringify()
+ r2repr = bounds_obj_combinations._stringify()
+
+ RANGE = self._col_type
+ range_typ = self._col_str
+
+ q = select(
+ cast(contains_range_obj_combinations, RANGE).label("r1"),
+ cast(bounds_obj_combinations, RANGE).label("r2"),
+ cast(contains_range_obj_combinations, RANGE).contains(
+ bounds_obj_combinations
+ ),
+ cast(contains_range_obj_combinations, RANGE).contained_by(
+ bounds_obj_combinations
+ ),
+ )
+ validate_q = select(
+ literal_column(f"'{r1repr}'::{range_typ}", RANGE).label("r1"),
+ literal_column(f"'{r2repr}'::{range_typ}", RANGE).label("r2"),
+ literal_column(
+ f"'{r1repr}'::{range_typ} @> '{r2repr}'::{range_typ}"
+ ),
+ literal_column(
+ f"'{r1repr}'::{range_typ} <@ '{r2repr}'::{range_typ}"
+ ),
+ )
+ orig_row = connection.execute(q).first()
+ validate_row = connection.execute(validate_q).first()
+ eq_(orig_row, validate_row)
+
+ r1, r2, contains, contained = orig_row
+ eq_(r1.contains(r2), contains)
+ eq_(r1.contained_by(r2), contained)
+ eq_(r2.contains(r1), contained)
+
+
+class _RangeTypeRoundTrip(_RangeComparisonFixtures, fixtures.TablesTest):
__requires__ = ("range_types",)
__backend__ = True
@@ -3861,6 +4051,9 @@ class _RangeTypeRoundTrip(fixtures.TablesTest):
)
cls.col = table.c.range
+ def test_stringify(self):
+ eq_(str(self._data_obj()), self._data_str())
+
def test_auto_cast_back_to_type(self, connection):
"""test that a straight pass of the range type without any context
will send appropriate casting info so that the driver can round
@@ -3995,10 +4188,16 @@ class _Int4RangeTests:
_col_str = "INT4RANGE"
def _data_str(self):
- return "[1,2)"
+ return "[1,4)"
def _data_obj(self):
- return Range(1, 2)
+ return Range(1, 4)
+
+ def _step_value_up(self, value):
+ return value + 1
+
+ def _step_value_down(self, value):
+ return value - 1
class _Int8RangeTests:
@@ -4007,10 +4206,16 @@ class _Int8RangeTests:
_col_str = "INT8RANGE"
def _data_str(self):
- return "[9223372036854775806,9223372036854775807)"
+ return "[9223372036854775306,9223372036854775800)"
def _data_obj(self):
- return Range(9223372036854775806, 9223372036854775807)
+ return Range(9223372036854775306, 9223372036854775800)
+
+ def _step_value_up(self, value):
+ return value + 5
+
+ def _step_value_down(self, value):
+ return value - 5
class _NumRangeTests:
@@ -4019,10 +4224,16 @@ class _NumRangeTests:
_col_str = "NUMRANGE"
def _data_str(self):
- return "[1.0,2.0)"
+ return "[1.0,9.0)"
def _data_obj(self):
- return Range(decimal.Decimal("1.0"), decimal.Decimal("2.0"))
+ return Range(decimal.Decimal("1.0"), decimal.Decimal("9.0"))
+
+ def _step_value_up(self, value):
+ return value + decimal.Decimal("1.8")
+
+ def _step_value_down(self, value):
+ return value - decimal.Decimal("1.8")
class _DateRangeTests:
@@ -4031,10 +4242,16 @@ class _DateRangeTests:
_col_str = "DATERANGE"
def _data_str(self):
- return "[2013-03-23,2013-03-24)"
+ return "[2013-03-23,2013-03-30)"
def _data_obj(self):
- return Range(datetime.date(2013, 3, 23), datetime.date(2013, 3, 24))
+ return Range(datetime.date(2013, 3, 23), datetime.date(2013, 3, 30))
+
+ def _step_value_up(self, value):
+ return value + datetime.timedelta(days=1)
+
+ def _step_value_down(self, value):
+ return value - datetime.timedelta(days=1)
class _DateTimeRangeTests:
@@ -4043,38 +4260,47 @@ class _DateTimeRangeTests:
_col_str = "TSRANGE"
def _data_str(self):
- return "[2013-03-23 14:30,2013-03-23 23:30)"
+ return "[2013-03-23 14:30:00,2013-03-30 23:30:00)"
def _data_obj(self):
return Range(
datetime.datetime(2013, 3, 23, 14, 30),
- datetime.datetime(2013, 3, 23, 23, 30),
+ datetime.datetime(2013, 3, 30, 23, 30),
)
+ def _step_value_up(self, value):
+ return value + datetime.timedelta(days=1)
+
+ def _step_value_down(self, value):
+ return value - datetime.timedelta(days=1)
+
class _DateTimeTZRangeTests:
_col_type = TSTZRANGE
_col_str = "TSTZRANGE"
- # make sure we use one, steady timestamp with timezone pair
- # for all parts of all these tests
- _tstzs = None
-
def tstzs(self):
- if self._tstzs is None:
- with testing.db.connect() as connection:
- lower = connection.scalar(func.current_timestamp().select())
- upper = lower + datetime.timedelta(1)
- self._tstzs = (lower, upper)
- return self._tstzs
+ tz = datetime.timezone(-datetime.timedelta(hours=5, minutes=30))
+
+ return (
+ datetime.datetime(2013, 3, 23, 14, 30, tzinfo=tz),
+ datetime.datetime(2013, 3, 30, 23, 30, tzinfo=tz),
+ )
def _data_str(self):
- return "[%s,%s)" % self.tstzs()
+ l, r = self.tstzs()
+ return f"[{l},{r})"
def _data_obj(self):
return Range(*self.tstzs())
+ def _step_value_up(self, value):
+ return value + datetime.timedelta(days=1)
+
+ def _step_value_down(self, value):
+ return value - datetime.timedelta(days=1)
+
class Int4RangeCompilationTest(_Int4RangeTests, _RangeTypeCompilation):
pass