summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/ansisql.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2007-08-11 16:04:38 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2007-08-11 16:04:38 +0000
commitac219b0192814cea0611f7251f7bb3927e5c3201 (patch)
tree1a64972749c52d02ed5224d1d0214b83d6185e50 /lib/sqlalchemy/ansisql.py
parente8793a5b59a05fb1d96c228bcd2e9f3ec381c0b4 (diff)
downloadsqlalchemy-ac219b0192814cea0611f7251f7bb3927e5c3201.tar.gz
- removed _calculate_correlations() methods, removed correlation_stack, select_stack;
all are merged into a single stack thats all within ansicompiler. clause visiting cut down significantly.
Diffstat (limited to 'lib/sqlalchemy/ansisql.py')
-rw-r--r--lib/sqlalchemy/ansisql.py97
1 files changed, 59 insertions, 38 deletions
diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py
index 14bae1d17..bfb08d337 100644
--- a/lib/sqlalchemy/ansisql.py
+++ b/lib/sqlalchemy/ansisql.py
@@ -150,12 +150,9 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor):
# actually present in the generated SQL
self.bind_names = {}
- # when the compiler visits a SELECT statement, the clause object is appended
- # to this stack. various visit operations will check this stack to determine
- # additional choices (TODO: it seems to be all typemap stuff. shouldnt this only
- # apply to the topmost-level SELECT statement ?)
- self.select_stack = []
-
+ # a stack. what recursive compiler doesn't have a stack ? :)
+ self.stack = []
+
# a dictionary of result-set column names (strings) to TypeEngine instances,
# which will be passed to a ResultProxy and used for resultset-level value conversion
self.typemap = {}
@@ -184,13 +181,6 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor):
# an ANSIIdentifierPreparer that formats the quoting of identifiers
self.preparer = dialect.identifier_preparer
- # a dictionary containing attributes about all select()
- # elements located within the clause, regarding which are subqueries, which are
- # selected from, and which elements should be correlated to an enclosing select.
- # used mostly to determine the list of FROM elements for each select statement, as well
- # as some dialect-specific rules regarding subqueries.
- self.correlate_state = {}
-
# for UPDATE and INSERT statements, a set of columns whos values are being set
# from a SQL expression (i.e., not one of the bind parameter values). if present,
# default-value logic in the Dialect knows not to fire off column defaults
@@ -230,11 +220,17 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor):
self.string = self.process(self.statement)
self.after_compile()
- def process(self, obj, **kwargs):
- return self.traverse_single(obj, **kwargs)
+ def process(self, obj, stack=None, **kwargs):
+ if stack:
+ self.stack.append(stack)
+ try:
+ return self.traverse_single(obj, **kwargs)
+ finally:
+ if stack:
+ self.stack.pop(-1)
def is_subquery(self, select):
- return self.correlate_state[select].get('is_subquery', False)
+ return self.stack and self.stack[-1].get('is_subquery')
def get_whereclause(self, obj):
"""given a FROM clause, return an additional WHERE condition that should be
@@ -292,7 +288,7 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor):
def visit_label(self, label):
labelname = self._truncated_identifier("colident", label.name)
- if self.select_stack:
+ if self.stack and self.stack[-1].get('select'):
self.typemap.setdefault(labelname.lower(), label.obj.type)
if isinstance(label.obj, sql._ColumnClause):
self.column_labels[label.obj._label] = labelname
@@ -310,7 +306,7 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor):
else:
name = column.name
- if self.select_stack:
+ if self.stack and self.stack[-1].get('select'):
# if we are within a visit to a Select, set up the "typemap"
# for this column which is used to translate result set values
self.typemap.setdefault(name.lower(), column.type)
@@ -369,28 +365,28 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor):
return self.process(clause.clause_expr)
def visit_cast(self, cast, **kwargs):
- if self.select_stack:
+ if self.stack and self.stack[-1].get('select'):
# not sure if we want to set the typemap here...
self.typemap.setdefault("CAST", cast.type)
return "CAST(%s AS %s)" % (self.process(cast.clause), self.process(cast.typeclause))
def visit_function(self, func, **kwargs):
- if self.select_stack:
+ if self.stack and self.stack[-1].get('select'):
self.typemap.setdefault(func.name, func.type)
if not self.apply_function_parens(func):
return ".".join(func.packagenames + [func.name])
else:
return ".".join(func.packagenames + [func.name]) + (not func.group and " " or "") + self.process(func.clause_expr)
- def visit_compound_select(self, cs, asfrom=False, **kwargs):
- text = string.join([self.process(c) for c in cs.selects], " " + cs.keyword + " ")
- group_by = self.process(cs._group_by_clause)
+ def visit_compound_select(self, cs, asfrom=False, parens=True, **kwargs):
+ text = string.join([self.process(c, asfrom=asfrom, parens=False) for c in cs.selects], " " + cs.keyword + " ")
+ group_by = self.process(cs._group_by_clause, asfrom=asfrom)
if group_by:
text += " GROUP BY " + group_by
text += self.order_by_clause(cs)
text += (cs._limit or cs._offset) and self.limit_clause(cs) or ""
- if asfrom:
+ if asfrom and parens:
return "(" + text + ")"
else:
return text
@@ -499,7 +495,7 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor):
# names look like table.colname. so if column is in a "selected from"
# subquery, label it synoymously with its column name
if \
- self.correlate_state[select].get('is_selected_from', False) and \
+ (self.stack and self.stack[-1].get('is_selected_from')) and \
isinstance(column, sql._ColumnClause) and \
not column.is_literal and \
column.table is not None and \
@@ -507,16 +503,37 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor):
return column.label(column.name)
else:
return None
-
- def visit_select(self, select, asfrom=False, **kwargs):
- select._calculate_correlations(self.correlate_state)
- self.select_stack.append(select)
+ def visit_select(self, select, asfrom=False, parens=True, **kwargs):
+
+ stack_entry = {'select':select}
+
+ if asfrom:
+ stack_entry['is_selected_from'] = stack_entry['is_subquery'] = True
+ elif self.stack and self.stack[-1].get('select'):
+ stack_entry['is_subquery'] = True
+
+ if self.stack and self.stack[-1].get('from'):
+ existingfroms = self.stack[-1]['from']
+ else:
+ existingfroms = None
+ froms = select._get_display_froms(existingfroms)
+
+ correlate_froms = util.Set()
+ for f in froms:
+ correlate_froms.add(f)
+ for f2 in f._get_from_objects():
+ correlate_froms.add(f2)
+
+ # TODO: might want to propigate existing froms for select(select(select))
+ # where innermost select should correlate to outermost
+# if existingfroms:
+# correlate_froms = correlate_froms.union(existingfroms)
+ stack_entry['from'] = correlate_froms
+ self.stack.append(stack_entry)
# the actual list of columns to print in the SELECT column list.
inner_columns = util.OrderedSet()
-
- froms = select._get_display_froms(self.correlate_state)
for co in select.inner_columns:
if select.use_labels:
@@ -533,9 +550,7 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor):
inner_columns.add(self.process(l))
else:
inner_columns.add(self.process(co))
-
- self.select_stack.pop(-1)
-
+
collist = string.join(inner_columns.difference(util.Set([None])), ', ')
text = " ".join(["SELECT"] + [self.process(x) for x in select._prefixes]) + " "
@@ -579,7 +594,9 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor):
text += (select._limit or select._offset) and self.limit_clause(select) or ""
text += self.for_update_clause(select)
- if asfrom:
+ self.stack.pop(-1)
+
+ if asfrom and parens:
return "(" + text + ")"
else:
return text
@@ -652,7 +669,7 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor):
" VALUES (" + string.join([c[1] for c in colparams], ', ') + ")")
def visit_update(self, update_stmt):
- update_stmt._calculate_correlations(self.correlate_state)
+ self.stack.append({'from':util.Set([update_stmt.table])})
# search for columns who will be required to have an explicit bound value.
# for updates, this includes Python-side "onupdate" defaults.
@@ -672,7 +689,9 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor):
if update_stmt._whereclause:
text += " WHERE " + self.process(update_stmt._whereclause)
-
+
+ self.stack.pop(-1)
+
return text
def _get_colparams(self, stmt, required_cols):
@@ -735,13 +754,15 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor):
return values
def visit_delete(self, delete_stmt):
- delete_stmt._calculate_correlations(self.correlate_state)
+ self.stack.append({'from':util.Set([delete_stmt.table])})
text = "DELETE FROM " + self.preparer.format_table(delete_stmt.table)
if delete_stmt._whereclause:
text += " WHERE " + self.process(delete_stmt._whereclause)
+ self.stack.pop(-1)
+
return text
def visit_savepoint(self, savepoint_stmt):