summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/expression.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/expression.py')
-rw-r--r--lib/sqlalchemy/sql/expression.py37
1 files changed, 24 insertions, 13 deletions
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py
index c21b102c7..6ad29218f 100644
--- a/lib/sqlalchemy/sql/expression.py
+++ b/lib/sqlalchemy/sql/expression.py
@@ -26,10 +26,11 @@ to stay the same in future releases.
"""
import re
+import datetime
from sqlalchemy import util, exceptions
from sqlalchemy.sql import operators, visitors
from sqlalchemy import types as sqltypes
-
+functions = None
__all__ = [
'Alias', 'ClauseElement',
@@ -596,7 +597,7 @@ def literal(value, type_=None):
bind-parameter translation for this literal.
"""
- return _BindParamClause('literal', value, type_=type_, unique=True)
+ return _BindParamClause(None, value, type_=type_, unique=True)
def label(name, obj):
"""Return a [sqlalchemy.sql.expression#_Label] object for the given [sqlalchemy.sql.expression#ColumnElement].
@@ -766,6 +767,14 @@ class _FunctionGenerator(object):
def __call__(self, *c, **kwargs):
o = self.opts.copy()
o.update(kwargs)
+ if len(self.__names) == 1:
+ global functions
+ if functions is None:
+ from sqlalchemy.sql import functions
+ func = getattr(functions, self.__names[-1].lower(), None)
+ if func is not None:
+ return func(*c, **o)
+
return _Function(self.__names[-1], packagenames=self.__names[0:-1], *c, **o)
func = _FunctionGenerator()
@@ -798,7 +807,7 @@ def _literal_as_column(element):
else:
return element
-def _literal_as_binds(element, name='literal', type_=None):
+def _literal_as_binds(element, name=None, type_=None):
if isinstance(element, Operators):
return element.expression_element()
elif _is_literal(element):
@@ -1344,7 +1353,7 @@ class _CompareMixin(ColumnOperators):
return lambda other: self.__operate(operator, other)
def _bind_param(self, obj):
- return _BindParamClause('literal', obj, type_=self.type, unique=True)
+ return _BindParamClause(None, obj, type_=self.type, unique=True)
def _check_literal(self, other):
if isinstance(other, _BindParamClause) and isinstance(other.type, sqltypes.NullType):
@@ -1762,6 +1771,10 @@ class _BindParamClause(ClauseElement, _CompareMixin):
unicode : sqltypes.NCHAR,
int : sqltypes.Integer,
float : sqltypes.Numeric,
+ datetime.date : sqltypes.Date,
+ datetime.datetime : sqltypes.DateTime,
+ datetime.time : sqltypes.Time,
+ datetime.timedelta : sqltypes.Interval,
type(None):sqltypes.NullType
}
@@ -1997,11 +2010,12 @@ class _Function(_CalculatedClause, FromClause):
def __init__(self, name, *clauses, **kwargs):
self.packagenames = kwargs.get('packagenames', None) or []
self.oid_column = None
- kwargs['operator'] = operators.comma_op
- _CalculatedClause.__init__(self, name, **kwargs)
- for c in clauses:
- self.append(c)
-
+ self.name = name
+ self._bind = kwargs.get('bind', None)
+ args = [_literal_as_binds(c, self.name) for c in clauses]
+ self.clause_expr = ClauseList(operator=operators.comma_op, group_contents=True, *args).self_group()
+ self.type = sqltypes.to_instance(kwargs.get('type_', None))
+
key = property(lambda self:self.name)
columns = property(lambda self:[self])
@@ -2012,9 +2026,6 @@ class _Function(_CalculatedClause, FromClause):
def get_children(self, **kwargs):
return _CalculatedClause.get_children(self, **kwargs)
- def append(self, clause):
- self.clauses.append(_literal_as_binds(clause, self.name))
-
class _Cast(ColumnElement):
@@ -2984,7 +2995,7 @@ class Select(_SelectBaseMixin, FromClause):
if from_obj:
self._froms = util.Set([
_is_literal(f) and _TextFromClause(f) or f
- for f in from_obj
+ for f in util.to_list(from_obj)
])
else:
self._froms = util.Set()