diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2009-12-29 02:41:16 +0000 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2009-12-29 02:41:16 +0000 |
commit | cf7c80b3f4a2ed9e2d2d2dd814839b9f50048815 (patch) | |
tree | 9ff011d37066cf874cc1c0f40415270b4a416b8e /lib/sqlalchemy/sql/util.py | |
parent | a572e39871a73588b19a8ce9e81a4b42148b7018 (diff) | |
download | sqlalchemy-cf7c80b3f4a2ed9e2d2d2dd814839b9f50048815.tar.gz |
- merge r6586 from 0.5 branch, for [ticket:1647]
Diffstat (limited to 'lib/sqlalchemy/sql/util.py')
-rw-r--r-- | lib/sqlalchemy/sql/util.py | 86 |
1 files changed, 85 insertions, 1 deletions
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 075d8c7ef..7bcc8e7d7 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -1,4 +1,4 @@ -from sqlalchemy import exc, schema, topological, util, sql +from sqlalchemy import exc, schema, topological, util, sql, types as sqltypes from sqlalchemy.sql import expression, operators, visitors from itertools import chain @@ -46,6 +46,90 @@ def find_join_source(clauses, join_to): else: return None, None +_date_affinities = None +def determine_date_affinity(expr): + """Given an expression, determine if it returns 'interval', 'date', or 'datetime'. + + the PG dialect uses this to generate the extract() function. + + It's less than ideal since it basically needs to duplicate PG's + date arithmetic rules. + + Rules are based on http://www.postgresql.org/docs/current/static/functions-datetime.html. + + Returns None if operators other than + or - are detected as well as types + outside of those above. + + """ + + global _date_affinities + if _date_affinities is None: + Date, DateTime, Integer, \ + Numeric, Interval, Time = \ + sqltypes.Date, sqltypes.DateTime,\ + sqltypes.Integer, sqltypes.Numeric,\ + sqltypes.Interval, sqltypes.Time + + _date_affinities = { + operators.add:{ + (Date, Integer):Date, + (Date, Interval):DateTime, + (Date, Time):DateTime, + (Interval, Interval):Interval, + (DateTime, Interval):DateTime, + (Interval, Time):Time, + }, + operators.sub:{ + (Date, Integer):Date, + (Date, Interval):DateTime, + (Time, Time):Interval, + (Time, Interval):Time, + (DateTime, Interval):DateTime, + (Interval, Interval):Interval, + (DateTime, DateTime):Interval, + }, + operators.mul:{ + (Integer, Interval):Interval, + (Interval, Numeric):Interval, + }, + operators.div: { + (Interval, Numeric):Interval + } + } + + if isinstance(expr, expression._BinaryExpression): + if expr.operator not in _date_affinities: + return None + + left_affin, right_affin = \ + determine_date_affinity(expr.left), \ + determine_date_affinity(expr.right) + + if operators.is_commutative(expr.operator): + key = tuple(sorted([left_affin, right_affin], key=lambda cls:cls.__name__)) + else: + key = (left_affin, right_affin) + + lookup = _date_affinities[expr.operator] + return lookup.get(key, None) + + # work around the fact that expressions put the wrong type + # on generated bind params when its "datetime + timedelta" + # and similar + if isinstance(expr, expression._BindParamClause): + type_ = sqltypes.type_map.get(type(expr.value), sqltypes.NullType)() + else: + type_ = expr.type + + affinities = set([sqltypes.Date, sqltypes.DateTime, + sqltypes.Interval, sqltypes.Time, sqltypes.Integer]) + + if type_ is not None and type_._type_affinity in affinities: + return type_._type_affinity + else: + return None + + def find_tables(clause, check_columns=False, include_aliases=False, include_joins=False, |