summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/dialects/postgresql/psycopg.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/dialects/postgresql/psycopg.py')
-rw-r--r--lib/sqlalchemy/dialects/postgresql/psycopg.py94
1 files changed, 94 insertions, 0 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg.py b/lib/sqlalchemy/dialects/postgresql/psycopg.py
index 414976a62..633357a74 100644
--- a/lib/sqlalchemy/dialects/postgresql/psycopg.py
+++ b/lib/sqlalchemy/dialects/postgresql/psycopg.py
@@ -57,9 +57,14 @@ release of SQLAlchemy 2.0, however.
Further documentation is available there.
""" # noqa
+from __future__ import annotations
+
import logging
import re
+from typing import cast
+from typing import TYPE_CHECKING
+from . import ranges
from ._psycopg_common import _PGDialect_common_psycopg
from ._psycopg_common import _PGExecutionContext_common_psycopg
from .base import INTERVAL
@@ -75,6 +80,9 @@ from ...sql import sqltypes
from ...util.concurrency import await_fallback
from ...util.concurrency import await_only
+if TYPE_CHECKING:
+ from typing import Iterable
+
logger = logging.getLogger("sqlalchemy.dialects.postgresql")
@@ -154,6 +162,78 @@ class _PGBoolean(sqltypes.Boolean):
render_bind_cast = True
+class _PsycopgRange(ranges.AbstractRange):
+ def bind_processor(self, dialect):
+ Range = cast(PGDialect_psycopg, dialect)._psycopg_Range
+
+ NoneType = type(None)
+
+ def to_range(value):
+ if not isinstance(value, (str, NoneType)):
+ value = Range(
+ value.lower, value.upper, value.bounds, value.empty
+ )
+ return value
+
+ return to_range
+
+ def result_processor(self, dialect, coltype):
+ def to_range(value):
+ if value is not None:
+ value = ranges.Range(
+ value._lower,
+ value._upper,
+ bounds=value._bounds if value._bounds else "[)",
+ empty=not value._bounds,
+ )
+ return value
+
+ return to_range
+
+
+class _PsycopgMultiRange(ranges.AbstractMultiRange):
+ def bind_processor(self, dialect):
+ Range = cast(PGDialect_psycopg, dialect)._psycopg_Range
+ Multirange = cast(PGDialect_psycopg, dialect)._psycopg_Multirange
+
+ NoneType = type(None)
+
+ def to_range(value):
+ if isinstance(value, (str, NoneType)):
+ return value
+
+ return Multirange(
+ [
+ Range(
+ element.lower,
+ element.upper,
+ element.bounds,
+ element.empty,
+ )
+ for element in cast("Iterable[ranges.Range]", value)
+ ]
+ )
+
+ return to_range
+
+ def result_processor(self, dialect, coltype):
+ def to_range(value):
+ if value is not None:
+ value = [
+ ranges.Range(
+ elem._lower,
+ elem._upper,
+ bounds=elem._bounds if elem._bounds else "[)",
+ empty=not elem._bounds,
+ )
+ for elem in value
+ ]
+
+ return value
+
+ return to_range
+
+
class PGExecutionContext_psycopg(_PGExecutionContext_common_psycopg):
pass
@@ -204,6 +284,8 @@ class PGDialect_psycopg(_PGDialect_common_psycopg):
sqltypes.Integer: _PGInteger,
sqltypes.SmallInteger: _PGSmallInteger,
sqltypes.BigInteger: _PGBigInteger,
+ ranges.AbstractRange: _PsycopgRange,
+ ranges.AbstractMultiRange: _PsycopgMultiRange,
},
)
@@ -314,6 +396,18 @@ class PGDialect_psycopg(_PGDialect_common_psycopg):
return TransactionStatus
+ @util.memoized_property
+ def _psycopg_Range(self):
+ from psycopg.types.range import Range
+
+ return Range
+
+ @util.memoized_property
+ def _psycopg_Multirange(self):
+ from psycopg.types.multirange import Multirange
+
+ return Multirange
+
def _do_isolation_level(self, connection, autocommit, isolation_level):
connection.autocommit = autocommit
connection.isolation_level = isolation_level