diff options
Diffstat (limited to 'lib/sqlalchemy/sql.py')
-rw-r--r-- | lib/sqlalchemy/sql.py | 35 |
1 files changed, 18 insertions, 17 deletions
diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index c113edaa3..6f51ccbe9 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -658,6 +658,17 @@ class ColumnElement(Selectable, CompareMixin): else: return self +class ColumnCollection(util.OrderedProperties): + def add(self, column): + self[column.key] = column + def __eq__(self, other): + l = [] + for c in other: + for local in self: + if c.shares_lineage(local): + l.append(c==local) + return and_(*l) + class FromClause(Selectable): """represents an element that can be used within the FROM clause of a SELECT statement.""" def __init__(self, name=None): @@ -671,7 +682,7 @@ class FromClause(Selectable): visitor.visit_fromclause(self) def count(self, whereclause=None, **params): if len(self.primary_key): - col = self.primary_key[0] + col = list(self.primary_key)[0] else: col = list(self.columns)[0] return select([func.count(col).label('tbl_row_count')], whereclause, from_obj=[self], **params) @@ -735,8 +746,8 @@ class FromClause(Selectable): if hasattr(self, '_columns'): # TODO: put a mutex here ? this is a key place for threading probs return - self._columns = util.OrderedProperties() - self._primary_key = [] + self._columns = ColumnCollection() + self._primary_key = ColumnCollection() self._foreign_keys = util.Set() self._orig_cols = {} export = self._exportable_columns() @@ -1082,7 +1093,7 @@ class Join(FromClause): def _proxy_column(self, column): self._columns[column._label] = column if column.primary_key: - self._primary_key.append(column) + self._primary_key.add(column) for f in column.foreign_keys: self._foreign_keys.add(f) return column @@ -1257,9 +1268,9 @@ class TableClause(FromClause): def __init__(self, name, *columns): super(TableClause, self).__init__(name) self.name = self.fullname = name - self._columns = util.OrderedProperties() + self._columns = ColumnCollection() self._foreign_keys = util.Set() - self._primary_key = [] + self._primary_key = util.Set() for c in columns: self.append_column(c) self._oid_column = ColumnClause('oid', self, hidden=True) @@ -1282,16 +1293,6 @@ class TableClause(FromClause): return self._orig_cols original_columns = property(_orig_columns) - def _clear(self): - """clears all attributes on this TableClause so that new items can be added again""" - self.columns.clear() - self.foreign_keys[:] = [] - self.primary_key[:] = [] - try: - delattr(self, '_orig_cols') - except AttributeError: - pass - def accept_visitor(self, visitor): visitor.visit_table(self) def _exportable_columns(self): @@ -1305,7 +1306,7 @@ class TableClause(FromClause): data[self] = self def count(self, whereclause=None, **params): if len(self.primary_key): - col = self.primary_key[0] + col = list(self.primary_key)[0] else: col = list(self.columns)[0] return select([func.count(col).label('tbl_row_count')], whereclause, from_obj=[self], **params) |