From 0b239579f03c82f7669d77c238e4fda8638fb9c3 Mon Sep 17 00:00:00 2001 From: Lele Gaifax Date: Sun, 27 Nov 2022 11:28:51 -0500 Subject: Add value-level hooks for SQL type detection; apply to Range Added additional type-detection for the new PostgreSQL :class:`_postgresql.Range` type, where previous cases that allowed the psycopg2-native range objects to be received directly by the DBAPI without SQLAlchemy intercepting them stopped working, as we now have our own value object. The :class:`_postgresql.Range` object has been enhanced such that SQLAlchemy Core detects it in otherwise ambiguous situations (such as comparison to dates) and applies appropriate bind handlers. Pull request courtesy Lele Gaifax. Fixes: #8884 Closes: #8886 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/8886 Pull-request-sha: 6e95e08a30597d3735ab38f2f1a2ccabd968852c Change-Id: I3ca277c826dcf4b5644f44eb251345b439a84ee4 --- lib/sqlalchemy/dialects/postgresql/ranges.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) (limited to 'lib/sqlalchemy/dialects/postgresql/ranges.py') diff --git a/lib/sqlalchemy/dialects/postgresql/ranges.py b/lib/sqlalchemy/dialects/postgresql/ranges.py index a4c39d063..6f13d462a 100644 --- a/lib/sqlalchemy/dialects/postgresql/ranges.py +++ b/lib/sqlalchemy/dialects/postgresql/ranges.py @@ -11,6 +11,7 @@ import dataclasses from datetime import date from datetime import datetime from datetime import timedelta +from decimal import Decimal from typing import Any from typing import Generic from typing import Optional @@ -84,6 +85,10 @@ class Range(Generic[_T]): def __bool__(self) -> bool: return self.empty + @property + def __sa_type_engine__(self): + return AbstractRange() + def _contains_value(self, value: _T) -> bool: "Check whether this range contains the given `value`." @@ -622,6 +627,21 @@ class AbstractRange(sqltypes.TypeEngine): else: return super().adapt(impltype) + def _resolve_for_literal(self, value): + spec = value.lower if value.lower is not None else value.upper + + if isinstance(spec, int): + return INT8RANGE() + elif isinstance(spec, (Decimal, float)): + return NUMRANGE() + elif isinstance(spec, datetime): + return TSRANGE() if not spec.tzinfo else TSTZRANGE() + elif isinstance(spec, date): + return DATERANGE() + else: + # empty Range, SQL datatype can't be determined here + return sqltypes.NULLTYPE + class comparator_factory(sqltypes.Concatenable.Comparator): """Define comparison operations for range types.""" -- cgit v1.2.1