summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/expression.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2010-03-08 18:27:35 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2010-03-08 18:27:35 -0500
commitb9cc0ef3260ab5c93af3c011db7d062ba15252fe (patch)
tree13144102923244abb1f6b89f79fb9627054a3dbd /lib/sqlalchemy/sql/expression.py
parent35254ca25a0d77f9be7ce944d40e5d3ecc641be6 (diff)
downloadsqlalchemy-b9cc0ef3260ab5c93af3c011db7d062ba15252fe.tar.gz
working on getting operators/left hand type awareness into the "bind" coercion. this system has to be figured out somehow
Diffstat (limited to 'lib/sqlalchemy/sql/expression.py')
-rw-r--r--lib/sqlalchemy/sql/expression.py55
1 files changed, 29 insertions, 26 deletions
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py
index 1c3961f1f..49ec34ab2 100644
--- a/lib/sqlalchemy/sql/expression.py
+++ b/lib/sqlalchemy/sql/expression.py
@@ -1443,7 +1443,7 @@ class _CompareMixin(ColumnOperators):
else:
raise exc.ArgumentError("Only '='/'!=' operators can be used with NULL")
else:
- obj = self._check_literal(obj)
+ obj = self._check_literal(op, obj)
if reverse:
return _BinaryExpression(obj,
@@ -1459,7 +1459,7 @@ class _CompareMixin(ColumnOperators):
negate=negate, modifiers=kwargs)
def __operate(self, op, obj, reverse=False):
- obj = self._check_literal(obj)
+ obj = self._check_literal(op, obj)
if reverse:
left, right = obj, self
@@ -1532,7 +1532,7 @@ class _CompareMixin(ColumnOperators):
"in() function accepts either a list of non-selectable values, "
"or a selectable: %r" % o)
else:
- o = self._bind_param(o)
+ o = self._bind_param(op, o)
args.append(o)
if len(args) == 0:
@@ -1558,7 +1558,7 @@ class _CompareMixin(ColumnOperators):
# use __radd__ to force string concat behavior
return self.__compare(
operators.like_op,
- literal_column("'%'", type_=sqltypes.String).__radd__(self._check_literal(other)),
+ literal_column("'%'", type_=sqltypes.String).__radd__(self._check_literal(operators.like_op, other)),
escape=escape)
def endswith(self, other, escape=None):
@@ -1566,7 +1566,7 @@ class _CompareMixin(ColumnOperators):
return self.__compare(
operators.like_op,
- literal_column("'%'", type_=sqltypes.String) + self._check_literal(other),
+ literal_column("'%'", type_=sqltypes.String) + self._check_literal(operators.like_op, other),
escape=escape)
def contains(self, other, escape=None):
@@ -1575,7 +1575,7 @@ class _CompareMixin(ColumnOperators):
return self.__compare(
operators.like_op,
literal_column("'%'", type_=sqltypes.String) +
- self._check_literal(other) +
+ self._check_literal(operators.like_op, other) +
literal_column("'%'", type_=sqltypes.String),
escape=escape)
@@ -1585,7 +1585,7 @@ class _CompareMixin(ColumnOperators):
The allowed contents of ``other`` are database backend specific.
"""
- return self.__compare(operators.match_op, self._check_literal(other))
+ return self.__compare(operators.match_op, self._check_literal(operators.match_op, other))
def label(self, name):
"""Produce a column label, i.e. ``<columnname> AS <name>``.
@@ -1615,8 +1615,8 @@ class _CompareMixin(ColumnOperators):
return _BinaryExpression(
self,
ClauseList(
- self._check_literal(cleft),
- self._check_literal(cright),
+ self._check_literal(operators.and_, cleft),
+ self._check_literal(operators.and_, cright),
operator=operators.and_,
group=False),
operators.between_op)
@@ -1651,17 +1651,18 @@ class _CompareMixin(ColumnOperators):
"""
return lambda other: self.__operate(operator, other)
- def _bind_param(self, obj):
- return _BindParamClause(None, obj, _fallback_type=self.type, unique=True)
+ def _bind_param(self, operator, obj):
+ return _BindParamClause(None, obj, _compared_to_operator=operator, _compared_to_type=self.type, unique=True)
- def _check_literal(self, other):
- if isinstance(other, _BindParamClause) and isinstance(other.type, sqltypes.NullType):
+ def _check_literal(self, operator, other):
+ if isinstance(other, _BindParamClause) and \
+ isinstance(other.type, sqltypes.NullType):
other.type = self.type
return other
elif hasattr(other, '__clause_element__'):
return other.__clause_element__()
elif not isinstance(other, ClauseElement):
- return self._bind_param(other)
+ return self._bind_param(operator, other)
elif isinstance(other, (_SelectBaseMixin, Alias)):
return other.as_scalar()
else:
@@ -2108,7 +2109,8 @@ class _BindParamClause(ColumnElement):
def __init__(self, key, value, type_=None, unique=False,
isoutparam=False, required=False,
- _fallback_type=None):
+ _compared_to_operator=None,
+ _compared_to_type=None):
"""Construct a _BindParamClause.
key
@@ -2154,9 +2156,10 @@ class _BindParamClause(ColumnElement):
self.required = required
if type_ is None:
- self.type = sqltypes.type_map.get(type(value), _fallback_type or sqltypes.NULLTYPE)
- if _fallback_type and _fallback_type._type_affinity == self.type._type_affinity:
- self.type = _fallback_type
+ if _compared_to_type is not None:
+ self.type = _compared_to_type._coerce_compared_value(_compared_to_operator, value)
+ else:
+ self.type = sqltypes.NULLTYPE
elif isinstance(type_, type):
self.type = type_()
else:
@@ -2434,9 +2437,9 @@ class _Tuple(ClauseList, ColumnElement):
def _select_iterable(self):
return (self, )
- def _bind_param(self, obj):
+ def _bind_param(self, operator, obj):
return _Tuple(*[
- _BindParamClause(None, o, _fallback_type=self.type, unique=True)
+ _BindParamClause(None, o, _compared_to_operator=operator, _compared_to_type=self.type, unique=True)
for o in obj
]).self_group()
@@ -2538,8 +2541,8 @@ class FunctionElement(Executable, ColumnElement, FromClause):
def execute(self):
return self.select().execute()
- def _bind_param(self, obj):
- return _BindParamClause(None, obj, _fallback_type=self.type, unique=True)
+ def _bind_param(self, operator, obj):
+ return _BindParamClause(None, obj, _compared_to_operator=operator, _compared_to_type=self.type, unique=True)
class Function(FunctionElement):
@@ -2555,8 +2558,8 @@ class Function(FunctionElement):
FunctionElement.__init__(self, *clauses, **kw)
- def _bind_param(self, obj):
- return _BindParamClause(self.name, obj, _fallback_type=self.type, unique=True)
+ def _bind_param(self, operator, obj):
+ return _BindParamClause(self.name, obj, _compared_to_operator=operator, _compared_to_type=self.type, unique=True)
class _Cast(ColumnElement):
@@ -3165,8 +3168,8 @@ class ColumnClause(_Immutable, ColumnElement):
else:
return []
- def _bind_param(self, obj):
- return _BindParamClause(self.name, obj, _fallback_type=self.type, unique=True)
+ def _bind_param(self, operator, obj):
+ return _BindParamClause(self.name, obj, _compared_to_operator=operator, _compared_to_type=self.type, unique=True)
def _make_proxy(self, selectable, name=None, attach=True):
# propagate the "is_literal" flag only if we are keeping our name,