diff options
Diffstat (limited to 'lib/sqlalchemy/dialects/postgresql/psycopg.py')
-rw-r--r-- | lib/sqlalchemy/dialects/postgresql/psycopg.py | 94 |
1 files changed, 94 insertions, 0 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg.py b/lib/sqlalchemy/dialects/postgresql/psycopg.py index 414976a62..633357a74 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg.py @@ -57,9 +57,14 @@ release of SQLAlchemy 2.0, however. Further documentation is available there. """ # noqa +from __future__ import annotations + import logging import re +from typing import cast +from typing import TYPE_CHECKING +from . import ranges from ._psycopg_common import _PGDialect_common_psycopg from ._psycopg_common import _PGExecutionContext_common_psycopg from .base import INTERVAL @@ -75,6 +80,9 @@ from ...sql import sqltypes from ...util.concurrency import await_fallback from ...util.concurrency import await_only +if TYPE_CHECKING: + from typing import Iterable + logger = logging.getLogger("sqlalchemy.dialects.postgresql") @@ -154,6 +162,78 @@ class _PGBoolean(sqltypes.Boolean): render_bind_cast = True +class _PsycopgRange(ranges.AbstractRange): + def bind_processor(self, dialect): + Range = cast(PGDialect_psycopg, dialect)._psycopg_Range + + NoneType = type(None) + + def to_range(value): + if not isinstance(value, (str, NoneType)): + value = Range( + value.lower, value.upper, value.bounds, value.empty + ) + return value + + return to_range + + def result_processor(self, dialect, coltype): + def to_range(value): + if value is not None: + value = ranges.Range( + value._lower, + value._upper, + bounds=value._bounds if value._bounds else "[)", + empty=not value._bounds, + ) + return value + + return to_range + + +class _PsycopgMultiRange(ranges.AbstractMultiRange): + def bind_processor(self, dialect): + Range = cast(PGDialect_psycopg, dialect)._psycopg_Range + Multirange = cast(PGDialect_psycopg, dialect)._psycopg_Multirange + + NoneType = type(None) + + def to_range(value): + if isinstance(value, (str, NoneType)): + return value + + return Multirange( + [ + Range( + element.lower, + element.upper, + element.bounds, + element.empty, + ) + for element in cast("Iterable[ranges.Range]", value) + ] + ) + + return to_range + + def result_processor(self, dialect, coltype): + def to_range(value): + if value is not None: + value = [ + ranges.Range( + elem._lower, + elem._upper, + bounds=elem._bounds if elem._bounds else "[)", + empty=not elem._bounds, + ) + for elem in value + ] + + return value + + return to_range + + class PGExecutionContext_psycopg(_PGExecutionContext_common_psycopg): pass @@ -204,6 +284,8 @@ class PGDialect_psycopg(_PGDialect_common_psycopg): sqltypes.Integer: _PGInteger, sqltypes.SmallInteger: _PGSmallInteger, sqltypes.BigInteger: _PGBigInteger, + ranges.AbstractRange: _PsycopgRange, + ranges.AbstractMultiRange: _PsycopgMultiRange, }, ) @@ -314,6 +396,18 @@ class PGDialect_psycopg(_PGDialect_common_psycopg): return TransactionStatus + @util.memoized_property + def _psycopg_Range(self): + from psycopg.types.range import Range + + return Range + + @util.memoized_property + def _psycopg_Multirange(self): + from psycopg.types.multirange import Multirange + + return Multirange + def _do_isolation_level(self, connection, autocommit, isolation_level): connection.autocommit = autocommit connection.isolation_level = isolation_level |