summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql.py')
-rw-r--r--lib/sqlalchemy/sql.py35
1 files changed, 18 insertions, 17 deletions
diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py
index c113edaa3..6f51ccbe9 100644
--- a/lib/sqlalchemy/sql.py
+++ b/lib/sqlalchemy/sql.py
@@ -658,6 +658,17 @@ class ColumnElement(Selectable, CompareMixin):
else:
return self
+class ColumnCollection(util.OrderedProperties):
+ def add(self, column):
+ self[column.key] = column
+ def __eq__(self, other):
+ l = []
+ for c in other:
+ for local in self:
+ if c.shares_lineage(local):
+ l.append(c==local)
+ return and_(*l)
+
class FromClause(Selectable):
"""represents an element that can be used within the FROM clause of a SELECT statement."""
def __init__(self, name=None):
@@ -671,7 +682,7 @@ class FromClause(Selectable):
visitor.visit_fromclause(self)
def count(self, whereclause=None, **params):
if len(self.primary_key):
- col = self.primary_key[0]
+ col = list(self.primary_key)[0]
else:
col = list(self.columns)[0]
return select([func.count(col).label('tbl_row_count')], whereclause, from_obj=[self], **params)
@@ -735,8 +746,8 @@ class FromClause(Selectable):
if hasattr(self, '_columns'):
# TODO: put a mutex here ? this is a key place for threading probs
return
- self._columns = util.OrderedProperties()
- self._primary_key = []
+ self._columns = ColumnCollection()
+ self._primary_key = ColumnCollection()
self._foreign_keys = util.Set()
self._orig_cols = {}
export = self._exportable_columns()
@@ -1082,7 +1093,7 @@ class Join(FromClause):
def _proxy_column(self, column):
self._columns[column._label] = column
if column.primary_key:
- self._primary_key.append(column)
+ self._primary_key.add(column)
for f in column.foreign_keys:
self._foreign_keys.add(f)
return column
@@ -1257,9 +1268,9 @@ class TableClause(FromClause):
def __init__(self, name, *columns):
super(TableClause, self).__init__(name)
self.name = self.fullname = name
- self._columns = util.OrderedProperties()
+ self._columns = ColumnCollection()
self._foreign_keys = util.Set()
- self._primary_key = []
+ self._primary_key = util.Set()
for c in columns:
self.append_column(c)
self._oid_column = ColumnClause('oid', self, hidden=True)
@@ -1282,16 +1293,6 @@ class TableClause(FromClause):
return self._orig_cols
original_columns = property(_orig_columns)
- def _clear(self):
- """clears all attributes on this TableClause so that new items can be added again"""
- self.columns.clear()
- self.foreign_keys[:] = []
- self.primary_key[:] = []
- try:
- delattr(self, '_orig_cols')
- except AttributeError:
- pass
-
def accept_visitor(self, visitor):
visitor.visit_table(self)
def _exportable_columns(self):
@@ -1305,7 +1306,7 @@ class TableClause(FromClause):
data[self] = self
def count(self, whereclause=None, **params):
if len(self.primary_key):
- col = self.primary_key[0]
+ col = list(self.primary_key)[0]
else:
col = list(self.columns)[0]
return select([func.count(col).label('tbl_row_count')], whereclause, from_obj=[self], **params)