diff options
Diffstat (limited to 'lib/sqlalchemy/sql.py')
-rw-r--r-- | lib/sqlalchemy/sql.py | 50 |
1 files changed, 37 insertions, 13 deletions
diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 9b3571384..cf42b2e83 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -522,6 +522,7 @@ class ColumnElement(Selectable, CompareMixin): primary_key = property(lambda self:getattr(self, '_primary_key', False)) foreign_key = property(lambda self:getattr(self, '_foreign_key', False)) original = property(lambda self:getattr(self, '_original', self)) + parent = property(lambda self:getattr(self, '_parent', self)) columns = property(lambda self:[self]) def _make_proxy(self, selectable, name=None): """creates a new ColumnElement representing this ColumnElement as it appears in the select list @@ -563,12 +564,19 @@ class FromClause(Selectable): return Join(self, right, isouter = True, *args, **kwargs) def alias(self, name=None): return Alias(self, name) - def _get_col_by_original(self, column): + def _get_col_by_original(self, column, raiseerr=True): """given a column which is a schema.Column object attached to a schema.Table object (i.e. an "original" column), return the Column object from this Selectable which corresponds to that original Column, or None if this Selectable does not contain the column.""" - return self.original_columns.get(column.original, None) + try: + return self.original_columns[column.original] + except KeyError: + if not raiseerr: + return None + else: + raise InvalidRequestError("cant get orig for " + str(column) + " with table " + column.table.id + " from table " + self.id) + def _get_exported_attribute(self, name): try: return getattr(self, name) @@ -595,6 +603,8 @@ class FromClause(Selectable): for co in column.columns: cp = self._proxy_column(co) self._orig_cols[co.original] = cp + if getattr(self, 'oid_column', None): + self._orig_cols[self.oid_column.original] = self.oid_column def _exportable_columns(self): return [] def _proxy_column(self, column): @@ -699,6 +709,8 @@ class ClauseList(ClauseElement): self.clauses.append(clause) def accept_visitor(self, visitor): for c in self.clauses: + if c is None: + raise "oh weird" + repr(self.clauses) c.accept_visitor(visitor) visitor.visit_clauselist(self) def _get_from_objects(self): @@ -904,13 +916,17 @@ class Join(FromClause): class Alias(FromClause): def __init__(self, selectable, alias = None): - while isinstance(selectable, Alias): - selectable = selectable.selectable + baseselectable = selectable + while isinstance(baseselectable, Alias): + baseselectable = baseselectable.selectable + self.original = baseselectable self.selectable = selectable if alias is None: - n = getattr(selectable, 'name') + n = getattr(self.original, 'name') if n is None: n = 'anon' + elif len(n) > 15: + n = n[0:15] alias = n + "_" + hex(random.randint(0, 65535))[2:] self.name = alias self.id = self.name @@ -949,6 +965,7 @@ class Label(ColumnElement): key = property(lambda s: s.name) _label = property(lambda s: s.name) original = property(lambda s:s.obj.original) + parent = property(lambda s:s.obj.parent) def accept_visitor(self, visitor): self.obj.accept_visitor(visitor) visitor.visit_label(self) @@ -1009,7 +1026,8 @@ class ColumnImpl(ColumnElement): engine = property(lambda s: s.column.engine) default_label = property(lambda s:s._label) - original = property(lambda self:self.column) + original = property(lambda self:self.column.original) + parent = property(lambda self:self.column.parent) columns = property(lambda self:[self.column]) def label(self, name): @@ -1073,6 +1091,9 @@ class TableImpl(FromClause): self._orig_cols= {} for c in self.columns: self._orig_cols[c.original] = c + oid = self.oid_column + if oid is not None: + self._orig_cols[oid.original] = oid return self._orig_cols oid_column = property(_oid_col) @@ -1132,13 +1153,18 @@ class SelectBaseMixin(object): if not hasattr(self, attribute): l = ClauseList(*clauses) setattr(self, attribute, l) - self.append_clause(prefix, l) else: getattr(self, attribute).clauses += clauses - def append_clause(self, keyword, clause): - if type(clause) == str: - clause = TextClause(clause) - self.clauses.append((keyword, clause)) + def _get_clauses(self): + # TODO: this is a little stupid. make ORDER BY/GROUP BY keywords handled by + # the compiler, make group_by_clause/order_by_clause regular attributes + x =[] + if getattr(self, 'group_by_clause', None): + x.append(("GROUP BY", self.group_by_clause)) + if getattr(self, 'order_by_clause', None): + x.append(("ORDER BY", self.order_by_clause)) + return x + clauses = property(_get_clauses) def select(self, whereclauses = None, **params): return select([self], whereclauses, **params) def _get_from_objects(self): @@ -1157,7 +1183,6 @@ class CompoundSelect(SelectBaseMixin, FromClause): for s in self.selects: s.group_by(None) s.order_by(None) - self.clauses = [] group_by = kwargs.get('group_by', None) if group_by: self.group_by(*group_by) @@ -1211,7 +1236,6 @@ class Select(SelectBaseMixin, FromClause): # indicates if this select statement is a subquery as a criterion # inside of a WHERE clause self.is_where = False - self.clauses = [] self.distinct = distinct self._text = None |