diff options
| author | zeeeeb <z3eee3b@gmail.com> | 2022-06-28 19:05:08 -0400 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-08-04 09:39:38 -0400 |
| commit | eeff036db61377b8159757e6cc2a2d83d85bf69e (patch) | |
| tree | b44ee342d06673a899d9b68d80f7130a8391bf24 /test/dialect/postgresql | |
| parent | 7c8572f004c0567482de98eb5697d8bb5e328b2d (diff) | |
| download | sqlalchemy-eeff036db61377b8159757e6cc2a2d83d85bf69e.tar.gz | |
fixes: #7156 - Adds support for PostgreSQL MultiRange type
This adds functionality for PostgreSQL MultiRange type, as discussed in Issue #7156.
As far as I can tell, only psycopg provides a [Multirange adaptation](https://www.psycopg.org/psycopg3/docs/basic/pgtypes.html#multirange-adaptation). Psycopg2 only supports a [Range adaptation/data type](https://www.psycopg.org/psycopg3/docs/basic/pgtypes.html#multirange-adaptation).
This pull request is:
- [ ] A documentation / typographical error fix
- Good to go, no issue or tests are needed
- [ ] A short code fix
- please include the issue number, and create an issue if none exists, which
must include a complete example of the issue. one line code fixes without an
issue and demonstration will not be accepted.
- Please include: `Fixes: #<issue number>` in the commit message
- please include tests. one line code fixes without tests will not be accepted.
- [x] A new feature implementation
- please include the issue number, and create an issue if none exists, which must
include a complete example of how the feature would look.
- Please include: `Fixes: #<issue number>` in the commit message
- please include tests.
Closes: #7816
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/7816
Pull-request-sha: 7e9e0c858dcdb58d4fcca24964ef8d58d1842d41
Change-Id: I345e0f58f534ac37709a7a4627b6de8ddd8fa89e
Diffstat (limited to 'test/dialect/postgresql')
| -rw-r--r-- | test/dialect/postgresql/test_dialect.py | 4 | ||||
| -rw-r--r-- | test/dialect/postgresql/test_types.py | 494 |
2 files changed, 494 insertions, 4 deletions
diff --git a/test/dialect/postgresql/test_dialect.py b/test/dialect/postgresql/test_dialect.py index 1ffd82ae4..9cbb0bca7 100644 --- a/test/dialect/postgresql/test_dialect.py +++ b/test/dialect/postgresql/test_dialect.py @@ -1207,7 +1207,7 @@ class MiscBackendTest( dbapi_conn.rollback() eq_(val, "off") - @testing.requires.psycopg_compatibility + @testing.requires.any_psycopg_compatibility def test_psycopg_non_standard_err(self): # note that psycopg2 is sometimes called psycopg2cffi # depending on platform @@ -1230,7 +1230,7 @@ class MiscBackendTest( assert isinstance(exception, exc.OperationalError) @testing.requires.no_coverage - @testing.requires.psycopg_compatibility + @testing.requires.any_psycopg_compatibility def test_notice_logging(self): log = logging.getLogger("sqlalchemy.dialects.postgresql") buf = logging.handlers.BufferingHandler(100) diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index 41bd1f5e7..f774300e6 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -1,4 +1,5 @@ # coding: utf-8 +from collections import defaultdict import datetime import decimal from enum import Enum as _PY_Enum @@ -37,18 +38,24 @@ from sqlalchemy import Unicode from sqlalchemy import util from sqlalchemy.dialects import postgresql from sqlalchemy.dialects.postgresql import array +from sqlalchemy.dialects.postgresql import DATEMULTIRANGE from sqlalchemy.dialects.postgresql import DATERANGE from sqlalchemy.dialects.postgresql import DOMAIN from sqlalchemy.dialects.postgresql import ENUM from sqlalchemy.dialects.postgresql import HSTORE from sqlalchemy.dialects.postgresql import hstore +from sqlalchemy.dialects.postgresql import INT4MULTIRANGE from sqlalchemy.dialects.postgresql import INT4RANGE +from sqlalchemy.dialects.postgresql import INT8MULTIRANGE from sqlalchemy.dialects.postgresql import INT8RANGE from sqlalchemy.dialects.postgresql import JSON from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.dialects.postgresql import NamedType +from sqlalchemy.dialects.postgresql import NUMMULTIRANGE from sqlalchemy.dialects.postgresql import NUMRANGE +from sqlalchemy.dialects.postgresql import TSMULTIRANGE from sqlalchemy.dialects.postgresql import TSRANGE +from sqlalchemy.dialects.postgresql import TSTZMULTIRANGE from sqlalchemy.dialects.postgresql import TSTZRANGE from sqlalchemy.exc import CompileError from sqlalchemy.orm import declarative_base @@ -2650,7 +2657,7 @@ class ArrayEnum(fixtures.TestBase): testing.combinations( sqltypes.ARRAY, postgresql.ARRAY, - (_ArrayOfEnum, testing.requires.psycopg_compatibility), + (_ArrayOfEnum, testing.requires.any_psycopg_compatibility), argnames="array_cls", )(fn) ) @@ -3701,7 +3708,7 @@ class _RangeTypeCompilation(AssertsCompiledSQL, fixtures.TestBase): class _RangeTypeRoundTrip(fixtures.TablesTest): - __requires__ = "range_types", "psycopg_compatibility" + __requires__ = "range_types", "any_psycopg_compatibility" __backend__ = True def extras(self): @@ -3934,6 +3941,489 @@ class DateTimeTZRangeRoundTripTest(_DateTimeTZRangeTests, _RangeTypeRoundTrip): pass +class _MultiRangeTypeCompilation(AssertsCompiledSQL, fixtures.TestBase): + __dialect__ = "postgresql" + + # operator tests + + @classmethod + def setup_test_class(cls): + table = Table( + "data_table", + MetaData(), + Column("multirange", cls._col_type, primary_key=True), + ) + cls.col = table.c.multirange + + def _test_clause(self, colclause, expected, type_): + self.assert_compile(colclause, expected) + is_(colclause.type._type_affinity, type_._type_affinity) + + def test_where_equal(self): + self._test_clause( + self.col == self._data_str(), + "data_table.multirange = %(multirange_1)s", + sqltypes.BOOLEANTYPE, + ) + + def test_where_not_equal(self): + self._test_clause( + self.col != self._data_str(), + "data_table.multirange <> %(multirange_1)s", + sqltypes.BOOLEANTYPE, + ) + + def test_where_is_null(self): + self._test_clause( + self.col == None, + "data_table.multirange IS NULL", + sqltypes.BOOLEANTYPE, + ) + + def test_where_is_not_null(self): + self._test_clause( + self.col != None, + "data_table.multirange IS NOT NULL", + sqltypes.BOOLEANTYPE, + ) + + def test_where_less_than(self): + self._test_clause( + self.col < self._data_str(), + "data_table.multirange < %(multirange_1)s", + sqltypes.BOOLEANTYPE, + ) + + def test_where_greater_than(self): + self._test_clause( + self.col > self._data_str(), + "data_table.multirange > %(multirange_1)s", + sqltypes.BOOLEANTYPE, + ) + + def test_where_less_than_or_equal(self): + self._test_clause( + self.col <= self._data_str(), + "data_table.multirange <= %(multirange_1)s", + sqltypes.BOOLEANTYPE, + ) + + def test_where_greater_than_or_equal(self): + self._test_clause( + self.col >= self._data_str(), + "data_table.multirange >= %(multirange_1)s", + sqltypes.BOOLEANTYPE, + ) + + def test_contains(self): + self._test_clause( + self.col.contains(self._data_str()), + "data_table.multirange @> %(multirange_1)s", + sqltypes.BOOLEANTYPE, + ) + + def test_contained_by(self): + self._test_clause( + self.col.contained_by(self._data_str()), + "data_table.multirange <@ %(multirange_1)s", + sqltypes.BOOLEANTYPE, + ) + + def test_overlaps(self): + self._test_clause( + self.col.overlaps(self._data_str()), + "data_table.multirange && %(multirange_1)s", + sqltypes.BOOLEANTYPE, + ) + + def test_strictly_left_of(self): + self._test_clause( + self.col << self._data_str(), + "data_table.multirange << %(multirange_1)s", + sqltypes.BOOLEANTYPE, + ) + self._test_clause( + self.col.strictly_left_of(self._data_str()), + "data_table.multirange << %(multirange_1)s", + sqltypes.BOOLEANTYPE, + ) + + def test_strictly_right_of(self): + self._test_clause( + self.col >> self._data_str(), + "data_table.multirange >> %(multirange_1)s", + sqltypes.BOOLEANTYPE, + ) + self._test_clause( + self.col.strictly_right_of(self._data_str()), + "data_table.multirange >> %(multirange_1)s", + sqltypes.BOOLEANTYPE, + ) + + def test_not_extend_right_of(self): + self._test_clause( + self.col.not_extend_right_of(self._data_str()), + "data_table.multirange &< %(multirange_1)s", + sqltypes.BOOLEANTYPE, + ) + + def test_not_extend_left_of(self): + self._test_clause( + self.col.not_extend_left_of(self._data_str()), + "data_table.multirange &> %(multirange_1)s", + sqltypes.BOOLEANTYPE, + ) + + def test_adjacent_to(self): + self._test_clause( + self.col.adjacent_to(self._data_str()), + "data_table.multirange -|- %(multirange_1)s", + sqltypes.BOOLEANTYPE, + ) + + def test_union(self): + self._test_clause( + self.col + self.col, + "data_table.multirange + data_table.multirange", + self.col.type, + ) + + def test_intersection(self): + self._test_clause( + self.col * self.col, + "data_table.multirange * data_table.multirange", + self.col.type, + ) + + def test_different(self): + self._test_clause( + self.col - self.col, + "data_table.multirange - data_table.multirange", + self.col.type, + ) + + +class _MultiRangeTypeRoundTrip(fixtures.TablesTest): + __requires__ = "range_types", "psycopg_only_compatibility" + __backend__ = True + + def extras(self): + # done this way so we don't get ImportErrors with + # older psycopg2 versions. + if testing.against("postgresql+psycopg"): + from psycopg.types.range import Range + from psycopg.types.multirange import Multirange + + class psycopg_extras: + def __init__(self): + self.data = defaultdict( + lambda: Range, Multirange=Multirange + ) + + def __getattr__(self, name): + return self.data[name] + + extras = psycopg_extras() + else: + assert False, "Unsupported MultiRange Dialect" + return extras + + @classmethod + def define_tables(cls, metadata): + # no reason ranges shouldn't be primary keys, + # so lets just use them as such + table = Table( + "data_table", + metadata, + Column("range", cls._col_type, primary_key=True), + ) + cls.col = table.c.range + + def test_actual_type(self): + eq_(str(self._col_type()), self._col_str) + + def test_reflect(self, connection): + from sqlalchemy import inspect + + insp = inspect(connection) + cols = insp.get_columns("data_table") + assert isinstance(cols[0]["type"], self._col_type) + + def _assert_data(self, conn): + data = conn.execute(select(self.tables.data_table.c.range)).fetchall() + eq_(data, [(self._data_obj(),)]) + + def test_insert_obj(self, connection): + connection.execute( + self.tables.data_table.insert(), {"range": self._data_obj()} + ) + self._assert_data(connection) + + def test_insert_text(self, connection): + connection.execute( + self.tables.data_table.insert(), {"range": self._data_str()} + ) + self._assert_data(connection) + + def test_union_result(self, connection): + # insert + connection.execute( + self.tables.data_table.insert(), {"range": self._data_str()} + ) + # select + range_ = self.tables.data_table.c.range + data = connection.execute(select(range_ + range_)).fetchall() + eq_(data, [(self._data_obj(),)]) + + def test_intersection_result(self, connection): + # insert + connection.execute( + self.tables.data_table.insert(), {"range": self._data_str()} + ) + # select + range_ = self.tables.data_table.c.range + data = connection.execute(select(range_ * range_)).fetchall() + eq_(data, [(self._data_obj(),)]) + + def test_difference_result(self, connection): + # insert + connection.execute( + self.tables.data_table.insert(), {"range": self._data_str()} + ) + # select + range_ = self.tables.data_table.c.range + data = connection.execute(select(range_ - range_)).fetchall() + eq_(data, [(self.extras().Multirange(),)]) + + +class _Int4MultiRangeTests: + + _col_type = INT4MULTIRANGE + _col_str = "INT4MULTIRANGE" + + def _data_str(self): + return "{[1,2), [3, 5), [9, 12)}" + + def _data_obj(self): + return self.extras().Multirange( + [ + self.extras().Range(1, 2), + self.extras().Range(3, 5), + self.extras().Range(9, 12), + ] + ) + + +class _Int8MultiRangeTests: + + _col_type = INT8MULTIRANGE + _col_str = "INT8MULTIRANGE" + + def _data_str(self): + return ( + "{[9223372036854775801,9223372036854775803)," + + "[9223372036854775805,9223372036854775807)}" + ) + + def _data_obj(self): + return self.extras().Multirange( + [ + self.extras().Range(9223372036854775801, 9223372036854775803), + self.extras().Range(9223372036854775805, 9223372036854775807), + ] + ) + + +class _NumMultiRangeTests: + + _col_type = NUMMULTIRANGE + _col_str = "NUMMULTIRANGE" + + def _data_str(self): + return "{[1.0,2.0), [3.0, 5.0), [9.0, 12.0)}" + + def _data_obj(self): + return self.extras().Multirange( + [ + self.extras().Range( + decimal.Decimal("1.0"), decimal.Decimal("2.0") + ), + self.extras().Range( + decimal.Decimal("3.0"), decimal.Decimal("5.0") + ), + self.extras().Range( + decimal.Decimal("9.0"), decimal.Decimal("12.0") + ), + ] + ) + + +class _DateMultiRangeTests: + + _col_type = DATEMULTIRANGE + _col_str = "DATEMULTIRANGE" + + def _data_str(self): + return "{[2013-03-23,2013-03-24), [2014-05-23,2014-05-24)}" + + def _data_obj(self): + return self.extras().Multirange( + [ + self.extras().Range( + datetime.date(2013, 3, 23), datetime.date(2013, 3, 24) + ), + self.extras().Range( + datetime.date(2014, 5, 23), datetime.date(2014, 5, 24) + ), + ] + ) + + +class _DateTimeMultiRangeTests: + + _col_type = TSMULTIRANGE + _col_str = "TSMULTIRANGE" + + def _data_str(self): + return ( + "{[2013-03-23 14:30,2013-03-23 23:30)," + + "[2014-05-23 14:30,2014-05-23 23:30)}" + ) + + def _data_obj(self): + return self.extras().Multirange( + [ + self.extras().Range( + datetime.datetime(2013, 3, 23, 14, 30), + datetime.datetime(2013, 3, 23, 23, 30), + ), + self.extras().Range( + datetime.datetime(2014, 5, 23, 14, 30), + datetime.datetime(2014, 5, 23, 23, 30), + ), + ] + ) + + +class _DateTimeTZMultiRangeTests: + + _col_type = TSTZMULTIRANGE + _col_str = "TSTZMULTIRANGE" + + # make sure we use one, steady timestamp with timezone pair + # for all parts of all these tests + _tstzs = None + _tstzs_delta = None + + def tstzs(self): + if self._tstzs is None: + with testing.db.connect() as connection: + lower = connection.scalar(func.current_timestamp().select()) + upper = lower + datetime.timedelta(1) + self._tstzs = (lower, upper) + return self._tstzs + + def tstzs_delta(self): + if self._tstzs_delta is None: + with testing.db.connect() as connection: + lower = connection.scalar( + func.current_timestamp().select() + ) + datetime.timedelta(3) + upper = lower + datetime.timedelta(2) + self._tstzs_delta = (lower, upper) + return self._tstzs_delta + + def _data_str(self): + tstzs_lower, tstzs_upper = self.tstzs() + tstzs_delta_lower, tstzs_delta_upper = self.tstzs_delta() + return "{{[{tl},{tu}), [{tdl},{tdu})}}".format( + tl=tstzs_lower, + tu=tstzs_upper, + tdl=tstzs_delta_lower, + tdu=tstzs_delta_upper, + ) + + def _data_obj(self): + return self.extras().Multirange( + [ + self.extras().Range(*self.tstzs()), + self.extras().Range(*self.tstzs_delta()), + ] + ) + + +class Int4MultiRangeCompilationTest( + _Int4MultiRangeTests, _MultiRangeTypeCompilation +): + pass + + +class Int4MultiRangeRoundTripTest( + _Int4MultiRangeTests, _MultiRangeTypeRoundTrip +): + pass + + +class Int8MultiRangeCompilationTest( + _Int8MultiRangeTests, _MultiRangeTypeCompilation +): + pass + + +class Int8MultiRangeRoundTripTest( + _Int8MultiRangeTests, _MultiRangeTypeRoundTrip +): + pass + + +class NumMultiRangeCompilationTest( + _NumMultiRangeTests, _MultiRangeTypeCompilation +): + pass + + +class NumMultiRangeRoundTripTest( + _NumMultiRangeTests, _MultiRangeTypeRoundTrip +): + pass + + +class DateMultiRangeCompilationTest( + _DateMultiRangeTests, _MultiRangeTypeCompilation +): + pass + + +class DateMultiRangeRoundTripTest( + _DateMultiRangeTests, _MultiRangeTypeRoundTrip +): + pass + + +class DateTimeMultiRangeCompilationTest( + _DateTimeMultiRangeTests, _MultiRangeTypeCompilation +): + pass + + +class DateTimeMultiRangeRoundTripTest( + _DateTimeMultiRangeTests, _MultiRangeTypeRoundTrip +): + pass + + +class DateTimeTZMultiRangeCompilationTest( + _DateTimeTZMultiRangeTests, _MultiRangeTypeCompilation +): + pass + + +class DateTimeTZRMultiangeRoundTripTest( + _DateTimeTZMultiRangeTests, _MultiRangeTypeRoundTrip +): + pass + + class JSONTest(AssertsCompiledSQL, fixtures.TestBase): __dialect__ = "postgresql" |
