diff options
author | Lele Gaifax <lele@metapensiero.it> | 2022-11-27 11:28:51 -0500 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-11-29 17:11:38 -0500 |
commit | 0b239579f03c82f7669d77c238e4fda8638fb9c3 (patch) | |
tree | ebe836c6d9f60362c4824843478122c7f725c2bd /lib/sqlalchemy/dialects/postgresql/ranges.py | |
parent | 61443aa62bbef158274ae393db399fec7f054c2d (diff) | |
download | sqlalchemy-0b239579f03c82f7669d77c238e4fda8638fb9c3.tar.gz |
Add value-level hooks for SQL type detection; apply to Range
Added additional type-detection for the new PostgreSQL
:class:`_postgresql.Range` type, where previous cases that allowed the
psycopg2-native range objects to be received directly by the DBAPI without
SQLAlchemy intercepting them stopped working, as we now have our own value
object. The :class:`_postgresql.Range` object has been enhanced such that
SQLAlchemy Core detects it in otherwise ambiguous situations (such as
comparison to dates) and applies appropriate bind handlers. Pull request
courtesy Lele Gaifax.
Fixes: #8884
Closes: #8886
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/8886
Pull-request-sha: 6e95e08a30597d3735ab38f2f1a2ccabd968852c
Change-Id: I3ca277c826dcf4b5644f44eb251345b439a84ee4
Diffstat (limited to 'lib/sqlalchemy/dialects/postgresql/ranges.py')
-rw-r--r-- | lib/sqlalchemy/dialects/postgresql/ranges.py | 20 |
1 files changed, 20 insertions, 0 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/ranges.py b/lib/sqlalchemy/dialects/postgresql/ranges.py index a4c39d063..6f13d462a 100644 --- a/lib/sqlalchemy/dialects/postgresql/ranges.py +++ b/lib/sqlalchemy/dialects/postgresql/ranges.py @@ -11,6 +11,7 @@ import dataclasses from datetime import date from datetime import datetime from datetime import timedelta +from decimal import Decimal from typing import Any from typing import Generic from typing import Optional @@ -84,6 +85,10 @@ class Range(Generic[_T]): def __bool__(self) -> bool: return self.empty + @property + def __sa_type_engine__(self): + return AbstractRange() + def _contains_value(self, value: _T) -> bool: "Check whether this range contains the given `value`." @@ -622,6 +627,21 @@ class AbstractRange(sqltypes.TypeEngine): else: return super().adapt(impltype) + def _resolve_for_literal(self, value): + spec = value.lower if value.lower is not None else value.upper + + if isinstance(spec, int): + return INT8RANGE() + elif isinstance(spec, (Decimal, float)): + return NUMRANGE() + elif isinstance(spec, datetime): + return TSRANGE() if not spec.tzinfo else TSTZRANGE() + elif isinstance(spec, date): + return DATERANGE() + else: + # empty Range, SQL datatype can't be determined here + return sqltypes.NULLTYPE + class comparator_factory(sqltypes.Concatenable.Comparator): """Define comparison operations for range types.""" |