diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2007-08-11 16:04:38 +0000 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2007-08-11 16:04:38 +0000 |
commit | ac219b0192814cea0611f7251f7bb3927e5c3201 (patch) | |
tree | 1a64972749c52d02ed5224d1d0214b83d6185e50 /lib/sqlalchemy/ansisql.py | |
parent | e8793a5b59a05fb1d96c228bcd2e9f3ec381c0b4 (diff) | |
download | sqlalchemy-ac219b0192814cea0611f7251f7bb3927e5c3201.tar.gz |
- removed _calculate_correlations() methods, removed correlation_stack, select_stack;
all are merged into a single stack thats all within ansicompiler. clause visiting cut down
significantly.
Diffstat (limited to 'lib/sqlalchemy/ansisql.py')
-rw-r--r-- | lib/sqlalchemy/ansisql.py | 97 |
1 files changed, 59 insertions, 38 deletions
diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index 14bae1d17..bfb08d337 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -150,12 +150,9 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor): # actually present in the generated SQL self.bind_names = {} - # when the compiler visits a SELECT statement, the clause object is appended - # to this stack. various visit operations will check this stack to determine - # additional choices (TODO: it seems to be all typemap stuff. shouldnt this only - # apply to the topmost-level SELECT statement ?) - self.select_stack = [] - + # a stack. what recursive compiler doesn't have a stack ? :) + self.stack = [] + # a dictionary of result-set column names (strings) to TypeEngine instances, # which will be passed to a ResultProxy and used for resultset-level value conversion self.typemap = {} @@ -184,13 +181,6 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor): # an ANSIIdentifierPreparer that formats the quoting of identifiers self.preparer = dialect.identifier_preparer - # a dictionary containing attributes about all select() - # elements located within the clause, regarding which are subqueries, which are - # selected from, and which elements should be correlated to an enclosing select. - # used mostly to determine the list of FROM elements for each select statement, as well - # as some dialect-specific rules regarding subqueries. - self.correlate_state = {} - # for UPDATE and INSERT statements, a set of columns whos values are being set # from a SQL expression (i.e., not one of the bind parameter values). if present, # default-value logic in the Dialect knows not to fire off column defaults @@ -230,11 +220,17 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor): self.string = self.process(self.statement) self.after_compile() - def process(self, obj, **kwargs): - return self.traverse_single(obj, **kwargs) + def process(self, obj, stack=None, **kwargs): + if stack: + self.stack.append(stack) + try: + return self.traverse_single(obj, **kwargs) + finally: + if stack: + self.stack.pop(-1) def is_subquery(self, select): - return self.correlate_state[select].get('is_subquery', False) + return self.stack and self.stack[-1].get('is_subquery') def get_whereclause(self, obj): """given a FROM clause, return an additional WHERE condition that should be @@ -292,7 +288,7 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor): def visit_label(self, label): labelname = self._truncated_identifier("colident", label.name) - if self.select_stack: + if self.stack and self.stack[-1].get('select'): self.typemap.setdefault(labelname.lower(), label.obj.type) if isinstance(label.obj, sql._ColumnClause): self.column_labels[label.obj._label] = labelname @@ -310,7 +306,7 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor): else: name = column.name - if self.select_stack: + if self.stack and self.stack[-1].get('select'): # if we are within a visit to a Select, set up the "typemap" # for this column which is used to translate result set values self.typemap.setdefault(name.lower(), column.type) @@ -369,28 +365,28 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor): return self.process(clause.clause_expr) def visit_cast(self, cast, **kwargs): - if self.select_stack: + if self.stack and self.stack[-1].get('select'): # not sure if we want to set the typemap here... self.typemap.setdefault("CAST", cast.type) return "CAST(%s AS %s)" % (self.process(cast.clause), self.process(cast.typeclause)) def visit_function(self, func, **kwargs): - if self.select_stack: + if self.stack and self.stack[-1].get('select'): self.typemap.setdefault(func.name, func.type) if not self.apply_function_parens(func): return ".".join(func.packagenames + [func.name]) else: return ".".join(func.packagenames + [func.name]) + (not func.group and " " or "") + self.process(func.clause_expr) - def visit_compound_select(self, cs, asfrom=False, **kwargs): - text = string.join([self.process(c) for c in cs.selects], " " + cs.keyword + " ") - group_by = self.process(cs._group_by_clause) + def visit_compound_select(self, cs, asfrom=False, parens=True, **kwargs): + text = string.join([self.process(c, asfrom=asfrom, parens=False) for c in cs.selects], " " + cs.keyword + " ") + group_by = self.process(cs._group_by_clause, asfrom=asfrom) if group_by: text += " GROUP BY " + group_by text += self.order_by_clause(cs) text += (cs._limit or cs._offset) and self.limit_clause(cs) or "" - if asfrom: + if asfrom and parens: return "(" + text + ")" else: return text @@ -499,7 +495,7 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor): # names look like table.colname. so if column is in a "selected from" # subquery, label it synoymously with its column name if \ - self.correlate_state[select].get('is_selected_from', False) and \ + (self.stack and self.stack[-1].get('is_selected_from')) and \ isinstance(column, sql._ColumnClause) and \ not column.is_literal and \ column.table is not None and \ @@ -507,16 +503,37 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor): return column.label(column.name) else: return None - - def visit_select(self, select, asfrom=False, **kwargs): - select._calculate_correlations(self.correlate_state) - self.select_stack.append(select) + def visit_select(self, select, asfrom=False, parens=True, **kwargs): + + stack_entry = {'select':select} + + if asfrom: + stack_entry['is_selected_from'] = stack_entry['is_subquery'] = True + elif self.stack and self.stack[-1].get('select'): + stack_entry['is_subquery'] = True + + if self.stack and self.stack[-1].get('from'): + existingfroms = self.stack[-1]['from'] + else: + existingfroms = None + froms = select._get_display_froms(existingfroms) + + correlate_froms = util.Set() + for f in froms: + correlate_froms.add(f) + for f2 in f._get_from_objects(): + correlate_froms.add(f2) + + # TODO: might want to propigate existing froms for select(select(select)) + # where innermost select should correlate to outermost +# if existingfroms: +# correlate_froms = correlate_froms.union(existingfroms) + stack_entry['from'] = correlate_froms + self.stack.append(stack_entry) # the actual list of columns to print in the SELECT column list. inner_columns = util.OrderedSet() - - froms = select._get_display_froms(self.correlate_state) for co in select.inner_columns: if select.use_labels: @@ -533,9 +550,7 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor): inner_columns.add(self.process(l)) else: inner_columns.add(self.process(co)) - - self.select_stack.pop(-1) - + collist = string.join(inner_columns.difference(util.Set([None])), ', ') text = " ".join(["SELECT"] + [self.process(x) for x in select._prefixes]) + " " @@ -579,7 +594,9 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor): text += (select._limit or select._offset) and self.limit_clause(select) or "" text += self.for_update_clause(select) - if asfrom: + self.stack.pop(-1) + + if asfrom and parens: return "(" + text + ")" else: return text @@ -652,7 +669,7 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor): " VALUES (" + string.join([c[1] for c in colparams], ', ') + ")") def visit_update(self, update_stmt): - update_stmt._calculate_correlations(self.correlate_state) + self.stack.append({'from':util.Set([update_stmt.table])}) # search for columns who will be required to have an explicit bound value. # for updates, this includes Python-side "onupdate" defaults. @@ -672,7 +689,9 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor): if update_stmt._whereclause: text += " WHERE " + self.process(update_stmt._whereclause) - + + self.stack.pop(-1) + return text def _get_colparams(self, stmt, required_cols): @@ -735,13 +754,15 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor): return values def visit_delete(self, delete_stmt): - delete_stmt._calculate_correlations(self.correlate_state) + self.stack.append({'from':util.Set([delete_stmt.table])}) text = "DELETE FROM " + self.preparer.format_table(delete_stmt.table) if delete_stmt._whereclause: text += " WHERE " + self.process(delete_stmt._whereclause) + self.stack.pop(-1) + return text def visit_savepoint(self, savepoint_stmt): |