diff options
Diffstat (limited to 'lib/sqlalchemy/sql.py')
-rw-r--r-- | lib/sqlalchemy/sql.py | 46 |
1 files changed, 33 insertions, 13 deletions
diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index c3d0f9de0..280ebd81c 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -914,25 +914,45 @@ class _CompareMixin(object): """produce an ``IN`` clause.""" if len(other) == 0: return self.__eq__(None) - elif len(other) == 1 and not hasattr(other[0], '_selectable'): - return self.__eq__(other[0]) - elif _is_literal(other[0]): - return self._compare('IN', ClauseList(parens=True, *[self._bind_param(o) for o in other]), negate='NOT IN') - else: - # assume *other is a single select. - # originally, this assumed possibly multiple selects and created a UNION, - # but we are now forcing explictness if a UNION is desired. - if len(other) > 1: - raise exceptions.InvalidRequestException("in() function accepts only multiple literal values, or a single selectable as an argument") - return self._compare('IN', other[0], negate='NOT IN') + elif len(other) == 1: + o = other[0] + if _is_literal(o) or isinstance( o, _CompareMixin): + return self.__eq__( o) #single item -> == + else: + assert hasattr( o, '_selectable') #better check? + return self._compare( 'IN', o, negate='NOT IN') #single selectable + + args = [] + for o in other: + if not _is_literal(o): + if not isinstance( o, _CompareMixin): + raise exceptions.InvalidRequestError( "in() function accepts either non-selectable values, or a single selectable: "+repr(o) ) + else: + o = self._bind_param(o) + args.append(o) + return self._compare( 'IN', ClauseList( parens=True, *args), negate='NOT IN') def startswith(self, other): """produce the clause ``LIKE '<other>%'``""" - return self._compare('LIKE', other + "%") + perc = isinstance(other,(str,unicode)) and '%' or literal('%',type= sqltypes.String) + return self._compare('LIKE', other + perc) def endswith(self, other): """produce the clause ``LIKE '%<other>'``""" - return self._compare('LIKE', "%" + other) + if isinstance(other,(str,unicode)): po = '%' + other + else: + po = literal('%', type= sqltypes.String) + other + po.type = sqltypes.to_instance( sqltypes.String) #force! + return self._compare('LIKE', po) + + def __radd__(self, other): + return self._bind_param(other)._operate('+', self) + def __rsub__(self, other): + return self._bind_param(other)._operate('-', self) + def __rmul__(self, other): + return self._bind_param(other)._operate('*', self) + def __rdiv__(self, other): + return self._bind_param(other)._operate('/', self) def label(self, name): """produce a column label, i.e. ``<columnname> AS <name>``""" |