diff options
Diffstat (limited to 'lib/sqlalchemy/sql/expression.py')
-rw-r--r-- | lib/sqlalchemy/sql/expression.py | 37 |
1 files changed, 24 insertions, 13 deletions
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index c21b102c7..6ad29218f 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -26,10 +26,11 @@ to stay the same in future releases. """ import re +import datetime from sqlalchemy import util, exceptions from sqlalchemy.sql import operators, visitors from sqlalchemy import types as sqltypes - +functions = None __all__ = [ 'Alias', 'ClauseElement', @@ -596,7 +597,7 @@ def literal(value, type_=None): bind-parameter translation for this literal. """ - return _BindParamClause('literal', value, type_=type_, unique=True) + return _BindParamClause(None, value, type_=type_, unique=True) def label(name, obj): """Return a [sqlalchemy.sql.expression#_Label] object for the given [sqlalchemy.sql.expression#ColumnElement]. @@ -766,6 +767,14 @@ class _FunctionGenerator(object): def __call__(self, *c, **kwargs): o = self.opts.copy() o.update(kwargs) + if len(self.__names) == 1: + global functions + if functions is None: + from sqlalchemy.sql import functions + func = getattr(functions, self.__names[-1].lower(), None) + if func is not None: + return func(*c, **o) + return _Function(self.__names[-1], packagenames=self.__names[0:-1], *c, **o) func = _FunctionGenerator() @@ -798,7 +807,7 @@ def _literal_as_column(element): else: return element -def _literal_as_binds(element, name='literal', type_=None): +def _literal_as_binds(element, name=None, type_=None): if isinstance(element, Operators): return element.expression_element() elif _is_literal(element): @@ -1344,7 +1353,7 @@ class _CompareMixin(ColumnOperators): return lambda other: self.__operate(operator, other) def _bind_param(self, obj): - return _BindParamClause('literal', obj, type_=self.type, unique=True) + return _BindParamClause(None, obj, type_=self.type, unique=True) def _check_literal(self, other): if isinstance(other, _BindParamClause) and isinstance(other.type, sqltypes.NullType): @@ -1762,6 +1771,10 @@ class _BindParamClause(ClauseElement, _CompareMixin): unicode : sqltypes.NCHAR, int : sqltypes.Integer, float : sqltypes.Numeric, + datetime.date : sqltypes.Date, + datetime.datetime : sqltypes.DateTime, + datetime.time : sqltypes.Time, + datetime.timedelta : sqltypes.Interval, type(None):sqltypes.NullType } @@ -1997,11 +2010,12 @@ class _Function(_CalculatedClause, FromClause): def __init__(self, name, *clauses, **kwargs): self.packagenames = kwargs.get('packagenames', None) or [] self.oid_column = None - kwargs['operator'] = operators.comma_op - _CalculatedClause.__init__(self, name, **kwargs) - for c in clauses: - self.append(c) - + self.name = name + self._bind = kwargs.get('bind', None) + args = [_literal_as_binds(c, self.name) for c in clauses] + self.clause_expr = ClauseList(operator=operators.comma_op, group_contents=True, *args).self_group() + self.type = sqltypes.to_instance(kwargs.get('type_', None)) + key = property(lambda self:self.name) columns = property(lambda self:[self]) @@ -2012,9 +2026,6 @@ class _Function(_CalculatedClause, FromClause): def get_children(self, **kwargs): return _CalculatedClause.get_children(self, **kwargs) - def append(self, clause): - self.clauses.append(_literal_as_binds(clause, self.name)) - class _Cast(ColumnElement): @@ -2984,7 +2995,7 @@ class Select(_SelectBaseMixin, FromClause): if from_obj: self._froms = util.Set([ _is_literal(f) and _TextFromClause(f) or f - for f in from_obj + for f in util.to_list(from_obj) ]) else: self._froms = util.Set() |