diff options
author | Philip Jenvey <pjenvey@underboss.org> | 2009-09-24 02:11:56 +0000 |
---|---|---|
committer | Philip Jenvey <pjenvey@underboss.org> | 2009-09-24 02:11:56 +0000 |
commit | 5a9c1b8824bb84aaf8baccdfa2780a94af5c0f44 (patch) | |
tree | abb0eed7f59567b73b0087d2f1e68c89254f7d2a /lib/sqlalchemy/sql | |
parent | 79ce8e89bd0537d26c8c3594557b2aa4c67f8f90 (diff) | |
download | sqlalchemy-5a9c1b8824bb84aaf8baccdfa2780a94af5c0f44.tar.gz |
merge from branches/clauseelement-nonzero
adds a __nonzero__ to _BinaryExpression to avoid faulty comparisons during hash
collisions (which only occur on Jython)
fixes #1547
Diffstat (limited to 'lib/sqlalchemy/sql')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 10 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/expression.py | 85 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/util.py | 14 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/visitors.py | 4 |
4 files changed, 62 insertions, 51 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 7bd0c1b05..b4b901067 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -368,11 +368,11 @@ class SQLCompiler(engine.Compiled): def visit_case(self, clause, **kwargs): x = "CASE " - if clause.value: + if clause.value is not None: x += self.process(clause.value) + " " for cond, result in clause.whens: x += "WHEN " + self.process(cond) + " THEN " + self.process(result) + " " - if clause.else_: + if clause.else_ is not None: x += "ELSE " + self.process(clause.else_) + " " x += "END" return x @@ -538,7 +538,7 @@ class SQLCompiler(engine.Compiled): if isinstance(column, sql._Label): return column - if select and select.use_labels and column._label: + if select is not None and select.use_labels and column._label: return _CompileLabel(column, column._label) if \ @@ -741,7 +741,7 @@ class SQLCompiler(engine.Compiled): if self.returning_precedes_values: text += " " + self.returning_clause(update_stmt, update_stmt._returning) - if update_stmt._whereclause: + if update_stmt._whereclause is not None: text += " WHERE " + self.process(update_stmt._whereclause) if self.returning and not self.returning_precedes_values: @@ -891,7 +891,7 @@ class SQLCompiler(engine.Compiled): if self.returning_precedes_values: text += " " + self.returning_clause(delete_stmt, delete_stmt._returning) - if delete_stmt._whereclause: + if delete_stmt._whereclause is not None: text += " WHERE " + self.process(delete_stmt._whereclause) if self.returning and not self.returning_precedes_values: diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 1b9ae6e8f..0ece67e20 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -974,7 +974,7 @@ def _no_literals(element): def _corresponding_column_or_error(fromclause, column, require_embedded=False): c = fromclause.corresponding_column(column, require_embedded=require_embedded) - if not c: + if c is None: raise exc.InvalidRequestError( "Given column '%s', attached to table '%s', " "failed to locate a corresponding column from table '%s'" @@ -1044,7 +1044,18 @@ class ClauseElement(Visitable): d = self.__dict__.copy() d.pop('_is_clone_of', None) return d - + + if util.jython: + def __hash__(self): + """Return a distinct hash code. + + ClauseElements may have special equality comparisons which + makes us rely on them having unique hash codes for use in + hash-based collections. Stock __hash__ doesn't guarantee + unique values on platforms with moving GCs. + """ + return id(self) + def _annotate(self, values): """return a copy of this ClauseElement with the given annotations dictionary. @@ -1264,16 +1275,8 @@ class ClauseElement(Visitable): def __invert__(self): return self._negate() - if util.jython: - def __hash__(self): - """Return a distinct hash code. - - ClauseElements may have special equality comparisons which - makes us rely on them having unique hash codes for use in - hash-based collections. Stock __hash__ doesn't guarantee - unique values on platforms with moving GCs. - """ - return id(self) + def __nonzero__(self): + raise TypeError("Boolean value of this clause is not defined") def _negate(self): if hasattr(self, 'negation_clause'): @@ -1797,10 +1800,9 @@ class ColumnCollection(util.OrderedProperties): # column names in their exported columns collection existing = self[key] if not existing.shares_lineage(value): - table = getattr(existing, 'table', None) and existing.table.description util.warn(("Column %r on table %r being replaced by another " "column with the same key. Consider use_labels " - "for select() statements.") % (key, table)) + "for select() statements.") % (key, getattr(existing, 'table', None))) util.OrderedProperties.__setitem__(self, key, value) def remove(self, column): @@ -2343,7 +2345,7 @@ class _Case(ColumnElement): except TypeError: pass - if value: + if value is not None: whenlist = [ (_literal_as_binds(c).self_group(), _literal_as_binds(r)) for (c, r) in whens ] @@ -2370,19 +2372,19 @@ class _Case(ColumnElement): self.else_ = None def _copy_internals(self, clone=_clone): - if self.value: + if self.value is not None: self.value = clone(self.value) self.whens = [(clone(x), clone(y)) for x, y in self.whens] - if self.else_: + if self.else_ is not None: self.else_ = clone(self.else_) def get_children(self, **kwargs): - if self.value: + if self.value is not None: yield self.value for x, y in self.whens: yield x yield y - if self.else_: + if self.else_ is not None: yield self.else_ @property @@ -2548,7 +2550,13 @@ class _BinaryExpression(ColumnElement): self.modifiers = {} else: self.modifiers = modifiers - + + def __nonzero__(self): + try: + return self.operator(hash(self.left), hash(self.right)) + except: + raise TypeError("Boolean value of this clause is not defined") + @property def _from_objects(self): return self.left._from_objects + self.right._from_objects @@ -3017,7 +3025,7 @@ class ColumnClause(_Immutable, ColumnElement): if self.is_literal: return None - elif self.table and self.table.named_with_column: + elif self.table is not None and self.table.named_with_column: if getattr(self.table, 'schema', None): label = self.table.schema + "_" + \ _escape_for_generated(self.table.name) + "_" + \ @@ -3047,7 +3055,7 @@ class ColumnClause(_Immutable, ColumnElement): @property def _from_objects(self): - if self.table: + if self.table is not None: return [self.table] else: return [] @@ -3264,7 +3272,7 @@ class _SelectBaseMixin(object): if len(clauses) == 1 and clauses[0] is None: self._order_by_clause = ClauseList() else: - if getattr(self, '_order_by_clause', None): + if getattr(self, '_order_by_clause', None) is not None: clauses = list(self._order_by_clause) + list(clauses) self._order_by_clause = ClauseList(*clauses) @@ -3277,7 +3285,7 @@ class _SelectBaseMixin(object): if len(clauses) == 1 and clauses[0] is None: self._group_by_clause = ClauseList() else: - if getattr(self, '_group_by_clause', None): + if getattr(self, '_group_by_clause', None) is not None: clauses = list(self._group_by_clause) + list(clauses) self._group_by_clause = ClauseList(*clauses) @@ -3433,28 +3441,31 @@ class Select(_SelectBaseMixin, FromClause): self._froms = util.OrderedSet() if columns: - self._raw_columns = [ - isinstance(c, _ScalarSelect) and - c.self_group(against=operators.comma_op) or c - for c in [_literal_as_column(c) for c in columns] - ] + self._raw_columns = [] + for c in columns: + c = _literal_as_column(c) + if isinstance(c, _ScalarSelect): + c = c.self_group(against=operators.comma_op) + self._raw_columns.append(c) self._froms.update(_from_objects(*self._raw_columns)) else: self._raw_columns = [] - if whereclause: + if whereclause is not None: self._whereclause = _literal_as_text(whereclause) self._froms.update(_from_objects(self._whereclause)) else: self._whereclause = None - if from_obj: - self._froms.update( - _is_literal(f) and _TextClause(f) or f - for f in util.to_list(from_obj)) + if from_obj is not None: + for f in util.to_list(from_obj): + if _is_literal(f): + self._froms.add(_TextClause(f)) + else: + self._froms.add(f) - if having: + if having is not None: self._having = _literal_as_text(having) else: self._having = None @@ -3977,7 +3988,7 @@ class Update(_ValuesBase): _ValuesBase.__init__(self, table, values) self._bind = bind self._returning = returning - if whereclause: + if whereclause is not None: self._whereclause = _literal_as_text(whereclause) else: self._whereclause = None @@ -4027,7 +4038,7 @@ class Delete(_UpdateBase): self.table = table self._returning = returning - if whereclause: + if whereclause is not None: self._whereclause = _literal_as_text(whereclause) else: self._whereclause = None diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 9be405e21..02165aad5 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -103,7 +103,7 @@ def join_condition(a, b, ignore_nonexistent_tables=False): else: raise - if col: + if col is not None: crit.append(col == fk.parent) constraints.add(fk.constraint) if a is not b: @@ -116,7 +116,7 @@ def join_condition(a, b, ignore_nonexistent_tables=False): else: raise - if col: + if col is not None: crit.append(col == fk.parent) constraints.add(fk.constraint) @@ -267,9 +267,9 @@ def splice_joins(left, right, stop_on=None): stack.append((right.left, right)) else: right = adapter.traverse(right) - if prevright: + if prevright is not None: prevright.left = right - if not ret: + if ret is None: ret = right return ret @@ -467,10 +467,10 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor): def _corresponding_column(self, col, require_embedded, _seen=util.EMPTY_SET): newcol = self.selectable.corresponding_column(col, require_embedded=require_embedded) - if not newcol and col in self.equivalents and col not in _seen: + if newcol is None and col in self.equivalents and col not in _seen: for equiv in self.equivalents[col]: newcol = self._corresponding_column(equiv, require_embedded=require_embedded, _seen=_seen.union([col])) - if newcol: + if newcol is not None: return newcol return newcol @@ -525,7 +525,7 @@ class ColumnAdapter(ClauseAdapter): def _locate_col(self, col): c = self._corresponding_column(col, False) - if not c: + if c is None: c = self.adapt_clause(col) # anonymize labels in case they have a hardcoded name diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 4471d4fb0..4a54375f8 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -150,7 +150,7 @@ class ReplacingCloningVisitor(CloningVisitor): def replace(elem): for v in self._visitor_iterator: e = v.replace(elem) - if e: + if e is not None: return e return replacement_traverse(obj, self.__traverse_options__, replace) @@ -236,7 +236,7 @@ def replacement_traverse(obj, opts, replace): def clone(element): newelem = replace(element) - if newelem: + if newelem is not None: stop_on.add(newelem) return newelem |