From fce1d954aa57feca9c163f9d8cf66df5e8ce7b65 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Thu, 4 Aug 2022 10:27:59 -0400 Subject: implement PG ranges/multiranges agnostically Ranges now work using a new Range object, multiranges as lists of Range objects (this is what asyncpg does. not sure why psycopg has a "Multirange" type). psycopg, psycopg2, and asyncpg are currently supported. It's not clear how to make ranges work with pg8000, likely needs string conversion; this is straightforward with the new archicture and can be added later. Fixes: #8178 Change-Id: Iab8d8382873d5c14199adbe3f09fd0dc17e2b9f1 --- lib/sqlalchemy/dialects/postgresql/asyncpg.py | 95 +++++++++++++++++++++++++++ 1 file changed, 95 insertions(+) (limited to 'lib/sqlalchemy/dialects/postgresql/asyncpg.py') diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py index d6385a5d6..38f8fddee 100644 --- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py +++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py @@ -119,14 +119,19 @@ client using this setting passed to :func:`_asyncio.create_async_engine`:: """ # noqa +from __future__ import annotations + import collections import collections.abc as collections_abc import decimal import json as _py_json import re import time +from typing import cast +from typing import TYPE_CHECKING from . import json +from . import ranges from .base import _DECIMAL_TYPES from .base import _FLOAT_TYPES from .base import _INT_TYPES @@ -148,6 +153,9 @@ from ...util.concurrency import asyncio from ...util.concurrency import await_fallback from ...util.concurrency import await_only +if TYPE_CHECKING: + from typing import Iterable + class AsyncpgString(sqltypes.String): render_bind_cast = True @@ -278,6 +286,91 @@ class AsyncpgCHAR(sqltypes.CHAR): render_bind_cast = True +class _AsyncpgRange(ranges.AbstractRange): + def bind_processor(self, dialect): + Range = dialect.dbapi.asyncpg.Range + + NoneType = type(None) + + def to_range(value): + if not isinstance(value, (str, NoneType)): + value = Range( + value.lower, + value.upper, + lower_inc=value.bounds[0] == "[", + upper_inc=value.bounds[1] == "]", + empty=value.empty, + ) + return value + + return to_range + + def result_processor(self, dialect, coltype): + def to_range(value): + if value is not None: + empty = value.isempty + value = ranges.Range( + value.lower, + value.upper, + bounds=f"{'[' if empty or value.lower_inc else '('}" # type: ignore # noqa: E501 + f"{']' if not empty and value.upper_inc else ')'}", + empty=empty, + ) + return value + + return to_range + + +class _AsyncpgMultiRange(ranges.AbstractMultiRange): + def bind_processor(self, dialect): + Range = dialect.dbapi.asyncpg.Range + + NoneType = type(None) + + def to_range(value): + if isinstance(value, (str, NoneType)): + return value + + def to_range(value): + if not isinstance(value, (str, NoneType)): + value = Range( + value.lower, + value.upper, + lower_inc=value.bounds[0] == "[", + upper_inc=value.bounds[1] == "]", + empty=value.empty, + ) + return value + + return [ + to_range(element) + for element in cast("Iterable[ranges.Range]", value) + ] + + return to_range + + def result_processor(self, dialect, coltype): + def to_range_array(value): + def to_range(rvalue): + if rvalue is not None: + empty = rvalue.isempty + rvalue = ranges.Range( + rvalue.lower, + rvalue.upper, + bounds=f"{'[' if empty or rvalue.lower_inc else '('}" # type: ignore # noqa: E501 + f"{']' if not empty and rvalue.upper_inc else ')'}", + empty=empty, + ) + return rvalue + + if value is not None: + value = [to_range(elem) for elem in value] + + return value + + return to_range_array + + class PGExecutionContext_asyncpg(PGExecutionContext): def handle_dbapi_exception(self, e): if isinstance( @@ -828,6 +921,8 @@ class PGDialect_asyncpg(PGDialect): OID: AsyncpgOID, REGCLASS: AsyncpgREGCLASS, sqltypes.CHAR: AsyncpgCHAR, + ranges.AbstractRange: _AsyncpgRange, + ranges.AbstractMultiRange: _AsyncpgMultiRange, }, ) is_async = True -- cgit v1.2.1