diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2010-03-08 18:27:35 -0500 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2010-03-08 18:27:35 -0500 |
commit | b9cc0ef3260ab5c93af3c011db7d062ba15252fe (patch) | |
tree | 13144102923244abb1f6b89f79fb9627054a3dbd /lib/sqlalchemy/sql/expression.py | |
parent | 35254ca25a0d77f9be7ce944d40e5d3ecc641be6 (diff) | |
download | sqlalchemy-b9cc0ef3260ab5c93af3c011db7d062ba15252fe.tar.gz |
working on getting operators/left hand type awareness into the "bind" coercion. this system has to be figured out somehow
Diffstat (limited to 'lib/sqlalchemy/sql/expression.py')
-rw-r--r-- | lib/sqlalchemy/sql/expression.py | 55 |
1 files changed, 29 insertions, 26 deletions
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 1c3961f1f..49ec34ab2 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -1443,7 +1443,7 @@ class _CompareMixin(ColumnOperators): else: raise exc.ArgumentError("Only '='/'!=' operators can be used with NULL") else: - obj = self._check_literal(obj) + obj = self._check_literal(op, obj) if reverse: return _BinaryExpression(obj, @@ -1459,7 +1459,7 @@ class _CompareMixin(ColumnOperators): negate=negate, modifiers=kwargs) def __operate(self, op, obj, reverse=False): - obj = self._check_literal(obj) + obj = self._check_literal(op, obj) if reverse: left, right = obj, self @@ -1532,7 +1532,7 @@ class _CompareMixin(ColumnOperators): "in() function accepts either a list of non-selectable values, " "or a selectable: %r" % o) else: - o = self._bind_param(o) + o = self._bind_param(op, o) args.append(o) if len(args) == 0: @@ -1558,7 +1558,7 @@ class _CompareMixin(ColumnOperators): # use __radd__ to force string concat behavior return self.__compare( operators.like_op, - literal_column("'%'", type_=sqltypes.String).__radd__(self._check_literal(other)), + literal_column("'%'", type_=sqltypes.String).__radd__(self._check_literal(operators.like_op, other)), escape=escape) def endswith(self, other, escape=None): @@ -1566,7 +1566,7 @@ class _CompareMixin(ColumnOperators): return self.__compare( operators.like_op, - literal_column("'%'", type_=sqltypes.String) + self._check_literal(other), + literal_column("'%'", type_=sqltypes.String) + self._check_literal(operators.like_op, other), escape=escape) def contains(self, other, escape=None): @@ -1575,7 +1575,7 @@ class _CompareMixin(ColumnOperators): return self.__compare( operators.like_op, literal_column("'%'", type_=sqltypes.String) + - self._check_literal(other) + + self._check_literal(operators.like_op, other) + literal_column("'%'", type_=sqltypes.String), escape=escape) @@ -1585,7 +1585,7 @@ class _CompareMixin(ColumnOperators): The allowed contents of ``other`` are database backend specific. """ - return self.__compare(operators.match_op, self._check_literal(other)) + return self.__compare(operators.match_op, self._check_literal(operators.match_op, other)) def label(self, name): """Produce a column label, i.e. ``<columnname> AS <name>``. @@ -1615,8 +1615,8 @@ class _CompareMixin(ColumnOperators): return _BinaryExpression( self, ClauseList( - self._check_literal(cleft), - self._check_literal(cright), + self._check_literal(operators.and_, cleft), + self._check_literal(operators.and_, cright), operator=operators.and_, group=False), operators.between_op) @@ -1651,17 +1651,18 @@ class _CompareMixin(ColumnOperators): """ return lambda other: self.__operate(operator, other) - def _bind_param(self, obj): - return _BindParamClause(None, obj, _fallback_type=self.type, unique=True) + def _bind_param(self, operator, obj): + return _BindParamClause(None, obj, _compared_to_operator=operator, _compared_to_type=self.type, unique=True) - def _check_literal(self, other): - if isinstance(other, _BindParamClause) and isinstance(other.type, sqltypes.NullType): + def _check_literal(self, operator, other): + if isinstance(other, _BindParamClause) and \ + isinstance(other.type, sqltypes.NullType): other.type = self.type return other elif hasattr(other, '__clause_element__'): return other.__clause_element__() elif not isinstance(other, ClauseElement): - return self._bind_param(other) + return self._bind_param(operator, other) elif isinstance(other, (_SelectBaseMixin, Alias)): return other.as_scalar() else: @@ -2108,7 +2109,8 @@ class _BindParamClause(ColumnElement): def __init__(self, key, value, type_=None, unique=False, isoutparam=False, required=False, - _fallback_type=None): + _compared_to_operator=None, + _compared_to_type=None): """Construct a _BindParamClause. key @@ -2154,9 +2156,10 @@ class _BindParamClause(ColumnElement): self.required = required if type_ is None: - self.type = sqltypes.type_map.get(type(value), _fallback_type or sqltypes.NULLTYPE) - if _fallback_type and _fallback_type._type_affinity == self.type._type_affinity: - self.type = _fallback_type + if _compared_to_type is not None: + self.type = _compared_to_type._coerce_compared_value(_compared_to_operator, value) + else: + self.type = sqltypes.NULLTYPE elif isinstance(type_, type): self.type = type_() else: @@ -2434,9 +2437,9 @@ class _Tuple(ClauseList, ColumnElement): def _select_iterable(self): return (self, ) - def _bind_param(self, obj): + def _bind_param(self, operator, obj): return _Tuple(*[ - _BindParamClause(None, o, _fallback_type=self.type, unique=True) + _BindParamClause(None, o, _compared_to_operator=operator, _compared_to_type=self.type, unique=True) for o in obj ]).self_group() @@ -2538,8 +2541,8 @@ class FunctionElement(Executable, ColumnElement, FromClause): def execute(self): return self.select().execute() - def _bind_param(self, obj): - return _BindParamClause(None, obj, _fallback_type=self.type, unique=True) + def _bind_param(self, operator, obj): + return _BindParamClause(None, obj, _compared_to_operator=operator, _compared_to_type=self.type, unique=True) class Function(FunctionElement): @@ -2555,8 +2558,8 @@ class Function(FunctionElement): FunctionElement.__init__(self, *clauses, **kw) - def _bind_param(self, obj): - return _BindParamClause(self.name, obj, _fallback_type=self.type, unique=True) + def _bind_param(self, operator, obj): + return _BindParamClause(self.name, obj, _compared_to_operator=operator, _compared_to_type=self.type, unique=True) class _Cast(ColumnElement): @@ -3165,8 +3168,8 @@ class ColumnClause(_Immutable, ColumnElement): else: return [] - def _bind_param(self, obj): - return _BindParamClause(self.name, obj, _fallback_type=self.type, unique=True) + def _bind_param(self, operator, obj): + return _BindParamClause(self.name, obj, _compared_to_operator=operator, _compared_to_type=self.type, unique=True) def _make_proxy(self, selectable, name=None, attach=True): # propagate the "is_literal" flag only if we are keeping our name, |