diff options
author | Lele Gaifax <lele@metapensiero.it> | 2022-11-02 08:33:41 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-11-03 09:30:38 -0400 |
commit | e8124b29b07fd17ab2f2b6892534dcc4b0797ab4 (patch) | |
tree | 5e8436a3f457b6bc2b7fb53c659afd92f209bc47 /test/dialect/postgresql | |
parent | 66f3533de86506327c753c1ea80b121692535745 (diff) | |
download | sqlalchemy-e8124b29b07fd17ab2f2b6892534dcc4b0797ab4.tar.gz |
Implement contains_value(), issubset() and issuperset() on PG Range
Added new methods :meth:`_postgresql.Range.contains` and
:meth:`_postgresql.Range.contained_by` to the new :class:`.Range` data
object, which mirror the behavior of the PostgreSQL ``@>`` and ``<@``
operators, as well as the
:meth:`_postgresql.AbstractRange.comparator_factory.contains` and
:meth:`_postgresql.AbstractRange.comparator_factory.contained_by` SQL
operator methods. Pull request courtesy Lele Gaifax.
Fixes: #8706
Closes: #8707
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/8707
Pull-request-sha: 3a74a0d93e63032ebee02992977498c717a077ff
Change-Id: Ief81ca5c31448640b26dfbc3defd4dde1d51e366
Diffstat (limited to 'test/dialect/postgresql')
-rw-r--r-- | test/dialect/postgresql/test_types.py | 270 |
1 files changed, 248 insertions, 22 deletions
diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index 91eada9a8..83cea8f15 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -20,6 +20,7 @@ from sqlalchemy import func from sqlalchemy import inspect from sqlalchemy import Integer from sqlalchemy import literal +from sqlalchemy import literal_column from sqlalchemy import MetaData from sqlalchemy import null from sqlalchemy import Numeric @@ -66,6 +67,8 @@ from sqlalchemy.sql import operators from sqlalchemy.sql import sqltypes from sqlalchemy.testing import expect_raises_message from sqlalchemy.testing import fixtures +from sqlalchemy.testing import is_false +from sqlalchemy.testing import is_true from sqlalchemy.testing.assertions import assert_raises from sqlalchemy.testing.assertions import assert_raises_message from sqlalchemy.testing.assertions import AssertsCompiledSQL @@ -3846,7 +3849,194 @@ class _RangeTypeCompilation(AssertsCompiledSQL, fixtures.TestBase): ) -class _RangeTypeRoundTrip(fixtures.TablesTest): +class _RangeComparisonFixtures: + def _data_str(self): + """return string form of a sample range""" + raise NotImplementedError() + + def _data_obj(self): + """return Range form of the same range""" + raise NotImplementedError() + + def _step_value_up(self, value): + """given a value, return a step up + + this is a value that given the lower end of the sample range, + would be less than the upper value of the range + + """ + raise NotImplementedError() + + def _step_value_down(self, value): + """given a value, return a step down + + this is a value that given the upper end of the sample range, + would be greater than the lower value of the range + + """ + raise NotImplementedError() + + def _value_values(self): + """Return a series of values related to the base range + + le = left equal + ll = lower than left + re = right equal + rh = higher than right + il = inside lower + ih = inside higher + + """ + spec = self._data_obj() + + le, re_ = spec.lower, spec.upper + + ll = self._step_value_down(le) + il = self._step_value_up(le) + rh = self._step_value_up(re_) + ih = self._step_value_down(re_) + + return {"le": le, "re_": re_, "ll": ll, "il": il, "rh": rh, "ih": ih} + + @testing.fixture( + params=[ + lambda **kw: Range(empty=True), + lambda **kw: Range(bounds="[)"), + lambda le, **kw: Range(upper=le, bounds="[)"), + lambda le, re_, **kw: Range(lower=le, upper=re_, bounds="[)"), + lambda le, re_, **kw: Range(lower=le, upper=re_, bounds="[)"), + lambda le, re_, **kw: Range(lower=le, upper=re_, bounds="[]"), + lambda le, re_, **kw: Range(lower=le, upper=re_, bounds="(]"), + lambda le, re_, **kw: Range(lower=le, upper=re_, bounds="()"), + lambda ll, le, **kw: Range(lower=ll, upper=le, bounds="[)"), + lambda il, ih, **kw: Range(lower=il, upper=ih, bounds="[)"), + lambda ll, le, **kw: Range(lower=ll, upper=le, bounds="(]"), + lambda ll, rh, **kw: Range(lower=ll, upper=rh, bounds="[)"), + ] + ) + def contains_range_obj_combinations(self, request): + """ranges that are used for range contains() contained_by() tests""" + data = self._value_values() + + range_ = request.param(**data) + yield range_ + + @testing.fixture( + params=[ + lambda l, r: Range(empty=True), + lambda l, r: Range(bounds="()"), + lambda l, r: Range(upper=r, bounds="(]"), + lambda l, r: Range(lower=l, bounds="[)"), + lambda l, r: Range(lower=l, upper=r, bounds="[)"), + lambda l, r: Range(lower=l, upper=r, bounds="[]"), + lambda l, r: Range(lower=l, upper=r, bounds="(]"), + lambda l, r: Range(lower=l, upper=r, bounds="()"), + ] + ) + def bounds_obj_combinations(self, request): + """sample ranges used for value and range contains()/contained_by() + tests""" + + obj = self._data_obj() + l, r = obj.lower, obj.upper + + template = request.param + value = template(l=l, r=r) + yield value + + @testing.fixture(params=["ll", "le", "il", "ih", "re_", "rh"]) + def value_combinations(self, request): + """sample values used for value contains() tests""" + data = self._value_values() + return data[request.param] + + def test_basic_py_sanity(self): + values = self._value_values() + + range_ = self._data_obj() + + is_true(range_.contains(Range(lower=values["il"], upper=values["ih"]))) + + is_true( + range_.contained_by(Range(lower=values["ll"], upper=values["rh"])) + ) + + is_true(range_.contains(values["il"])) + + is_false( + range_.contains(Range(lower=values["ll"], upper=values["ih"])) + ) + + is_false(range_.contains(values["rh"])) + + def test_contains_value( + self, connection, bounds_obj_combinations, value_combinations + ): + range_ = bounds_obj_combinations + range_typ = self._col_str + + strvalue = range_._stringify() + + v = value_combinations + RANGE = self._col_type + + q = select( + literal_column(f"'{strvalue}'::{range_typ}", RANGE).label("r1"), + cast(range_, RANGE).label("r2"), + ) + literal_range, cast_range = connection.execute(q).first() + eq_(literal_range, cast_range) + + q = select( + cast(range_, RANGE), + cast(range_, RANGE).contains(v), + ) + r, expected = connection.execute(q).first() + eq_(r.contains(v), expected) + + def test_contains_range( + self, + connection, + bounds_obj_combinations, + contains_range_obj_combinations, + ): + r1repr = contains_range_obj_combinations._stringify() + r2repr = bounds_obj_combinations._stringify() + + RANGE = self._col_type + range_typ = self._col_str + + q = select( + cast(contains_range_obj_combinations, RANGE).label("r1"), + cast(bounds_obj_combinations, RANGE).label("r2"), + cast(contains_range_obj_combinations, RANGE).contains( + bounds_obj_combinations + ), + cast(contains_range_obj_combinations, RANGE).contained_by( + bounds_obj_combinations + ), + ) + validate_q = select( + literal_column(f"'{r1repr}'::{range_typ}", RANGE).label("r1"), + literal_column(f"'{r2repr}'::{range_typ}", RANGE).label("r2"), + literal_column( + f"'{r1repr}'::{range_typ} @> '{r2repr}'::{range_typ}" + ), + literal_column( + f"'{r1repr}'::{range_typ} <@ '{r2repr}'::{range_typ}" + ), + ) + orig_row = connection.execute(q).first() + validate_row = connection.execute(validate_q).first() + eq_(orig_row, validate_row) + + r1, r2, contains, contained = orig_row + eq_(r1.contains(r2), contains) + eq_(r1.contained_by(r2), contained) + eq_(r2.contains(r1), contained) + + +class _RangeTypeRoundTrip(_RangeComparisonFixtures, fixtures.TablesTest): __requires__ = ("range_types",) __backend__ = True @@ -3861,6 +4051,9 @@ class _RangeTypeRoundTrip(fixtures.TablesTest): ) cls.col = table.c.range + def test_stringify(self): + eq_(str(self._data_obj()), self._data_str()) + def test_auto_cast_back_to_type(self, connection): """test that a straight pass of the range type without any context will send appropriate casting info so that the driver can round @@ -3995,10 +4188,16 @@ class _Int4RangeTests: _col_str = "INT4RANGE" def _data_str(self): - return "[1,2)" + return "[1,4)" def _data_obj(self): - return Range(1, 2) + return Range(1, 4) + + def _step_value_up(self, value): + return value + 1 + + def _step_value_down(self, value): + return value - 1 class _Int8RangeTests: @@ -4007,10 +4206,16 @@ class _Int8RangeTests: _col_str = "INT8RANGE" def _data_str(self): - return "[9223372036854775806,9223372036854775807)" + return "[9223372036854775306,9223372036854775800)" def _data_obj(self): - return Range(9223372036854775806, 9223372036854775807) + return Range(9223372036854775306, 9223372036854775800) + + def _step_value_up(self, value): + return value + 5 + + def _step_value_down(self, value): + return value - 5 class _NumRangeTests: @@ -4019,10 +4224,16 @@ class _NumRangeTests: _col_str = "NUMRANGE" def _data_str(self): - return "[1.0,2.0)" + return "[1.0,9.0)" def _data_obj(self): - return Range(decimal.Decimal("1.0"), decimal.Decimal("2.0")) + return Range(decimal.Decimal("1.0"), decimal.Decimal("9.0")) + + def _step_value_up(self, value): + return value + decimal.Decimal("1.8") + + def _step_value_down(self, value): + return value - decimal.Decimal("1.8") class _DateRangeTests: @@ -4031,10 +4242,16 @@ class _DateRangeTests: _col_str = "DATERANGE" def _data_str(self): - return "[2013-03-23,2013-03-24)" + return "[2013-03-23,2013-03-30)" def _data_obj(self): - return Range(datetime.date(2013, 3, 23), datetime.date(2013, 3, 24)) + return Range(datetime.date(2013, 3, 23), datetime.date(2013, 3, 30)) + + def _step_value_up(self, value): + return value + datetime.timedelta(days=1) + + def _step_value_down(self, value): + return value - datetime.timedelta(days=1) class _DateTimeRangeTests: @@ -4043,38 +4260,47 @@ class _DateTimeRangeTests: _col_str = "TSRANGE" def _data_str(self): - return "[2013-03-23 14:30,2013-03-23 23:30)" + return "[2013-03-23 14:30:00,2013-03-30 23:30:00)" def _data_obj(self): return Range( datetime.datetime(2013, 3, 23, 14, 30), - datetime.datetime(2013, 3, 23, 23, 30), + datetime.datetime(2013, 3, 30, 23, 30), ) + def _step_value_up(self, value): + return value + datetime.timedelta(days=1) + + def _step_value_down(self, value): + return value - datetime.timedelta(days=1) + class _DateTimeTZRangeTests: _col_type = TSTZRANGE _col_str = "TSTZRANGE" - # make sure we use one, steady timestamp with timezone pair - # for all parts of all these tests - _tstzs = None - def tstzs(self): - if self._tstzs is None: - with testing.db.connect() as connection: - lower = connection.scalar(func.current_timestamp().select()) - upper = lower + datetime.timedelta(1) - self._tstzs = (lower, upper) - return self._tstzs + tz = datetime.timezone(-datetime.timedelta(hours=5, minutes=30)) + + return ( + datetime.datetime(2013, 3, 23, 14, 30, tzinfo=tz), + datetime.datetime(2013, 3, 30, 23, 30, tzinfo=tz), + ) def _data_str(self): - return "[%s,%s)" % self.tstzs() + l, r = self.tstzs() + return f"[{l},{r})" def _data_obj(self): return Range(*self.tstzs()) + def _step_value_up(self, value): + return value + datetime.timedelta(days=1) + + def _step_value_down(self, value): + return value - datetime.timedelta(days=1) + class Int4RangeCompilationTest(_Int4RangeTests, _RangeTypeCompilation): pass |