summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/dialects/postgresql/base.py10
-rw-r--r--lib/sqlalchemy/sql/util.py86
-rw-r--r--lib/sqlalchemy/types.py4
3 files changed, 98 insertions, 2 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py
index 7a9e2e710..a71984be4 100644
--- a/lib/sqlalchemy/dialects/postgresql/base.py
+++ b/lib/sqlalchemy/dialects/postgresql/base.py
@@ -322,8 +322,16 @@ class PGCompiler(compiler.SQLCompiler):
def visit_extract(self, extract, **kwargs):
field = self.extract_map.get(extract.field, extract.field)
+ affinity = sql_util.determine_date_affinity(extract.expr)
+
+ casts = {sqltypes.Date:'date', sqltypes.DateTime:'timestamp', sqltypes.Interval:'interval', sqltypes.Time:'time'}
+ cast = casts.get(affinity, None)
+ if isinstance(extract.expr, sql.ColumnElement) and cast is not None:
+ expr = extract.expr.op('::')(sql.literal_column(cast))
+ else:
+ expr = extract.expr
return "EXTRACT(%s FROM %s)" % (
- field, self.process(extract.expr.op('::')(sql.literal_column('timestamp'))))
+ field, self.process(expr))
class PGDDLCompiler(compiler.DDLCompiler):
def get_column_specification(self, column, **kwargs):
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index 075d8c7ef..7bcc8e7d7 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -1,4 +1,4 @@
-from sqlalchemy import exc, schema, topological, util, sql
+from sqlalchemy import exc, schema, topological, util, sql, types as sqltypes
from sqlalchemy.sql import expression, operators, visitors
from itertools import chain
@@ -46,6 +46,90 @@ def find_join_source(clauses, join_to):
else:
return None, None
+_date_affinities = None
+def determine_date_affinity(expr):
+ """Given an expression, determine if it returns 'interval', 'date', or 'datetime'.
+
+ the PG dialect uses this to generate the extract() function.
+
+ It's less than ideal since it basically needs to duplicate PG's
+ date arithmetic rules.
+
+ Rules are based on http://www.postgresql.org/docs/current/static/functions-datetime.html.
+
+ Returns None if operators other than + or - are detected as well as types
+ outside of those above.
+
+ """
+
+ global _date_affinities
+ if _date_affinities is None:
+ Date, DateTime, Integer, \
+ Numeric, Interval, Time = \
+ sqltypes.Date, sqltypes.DateTime,\
+ sqltypes.Integer, sqltypes.Numeric,\
+ sqltypes.Interval, sqltypes.Time
+
+ _date_affinities = {
+ operators.add:{
+ (Date, Integer):Date,
+ (Date, Interval):DateTime,
+ (Date, Time):DateTime,
+ (Interval, Interval):Interval,
+ (DateTime, Interval):DateTime,
+ (Interval, Time):Time,
+ },
+ operators.sub:{
+ (Date, Integer):Date,
+ (Date, Interval):DateTime,
+ (Time, Time):Interval,
+ (Time, Interval):Time,
+ (DateTime, Interval):DateTime,
+ (Interval, Interval):Interval,
+ (DateTime, DateTime):Interval,
+ },
+ operators.mul:{
+ (Integer, Interval):Interval,
+ (Interval, Numeric):Interval,
+ },
+ operators.div: {
+ (Interval, Numeric):Interval
+ }
+ }
+
+ if isinstance(expr, expression._BinaryExpression):
+ if expr.operator not in _date_affinities:
+ return None
+
+ left_affin, right_affin = \
+ determine_date_affinity(expr.left), \
+ determine_date_affinity(expr.right)
+
+ if operators.is_commutative(expr.operator):
+ key = tuple(sorted([left_affin, right_affin], key=lambda cls:cls.__name__))
+ else:
+ key = (left_affin, right_affin)
+
+ lookup = _date_affinities[expr.operator]
+ return lookup.get(key, None)
+
+ # work around the fact that expressions put the wrong type
+ # on generated bind params when its "datetime + timedelta"
+ # and similar
+ if isinstance(expr, expression._BindParamClause):
+ type_ = sqltypes.type_map.get(type(expr.value), sqltypes.NullType)()
+ else:
+ type_ = expr.type
+
+ affinities = set([sqltypes.Date, sqltypes.DateTime,
+ sqltypes.Interval, sqltypes.Time, sqltypes.Integer])
+
+ if type_ is not None and type_._type_affinity in affinities:
+ return type_._type_affinity
+ else:
+ return None
+
+
def find_tables(clause, check_columns=False,
include_aliases=False, include_joins=False,
diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py
index 3c42de2b8..c217e594f 100644
--- a/lib/sqlalchemy/types.py
+++ b/lib/sqlalchemy/types.py
@@ -1213,6 +1213,10 @@ class Interval(TypeDecorator):
return value - epoch
return process
+ @property
+ def _type_affinity(self):
+ return Interval
+
class FLOAT(Float):
"""The SQL FLOAT type."""